Compare commits

...

116 Commits

Author SHA1 Message Date
Elias Schneider
5dcf69e974 release: 0.45.0 2025-03-30 00:12:19 +01:00
Alessandro (Ale) Segala
519d58d88c fix: use WAL for SQLite by default and set busy_timeout (#388)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 23:12:48 +01:00
Alessandro (Ale) Segala
b3b43a56af refactor: do not include test controller in production builds (#402)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-29 22:11:25 +00:00
Elias Schneider
fc68cf7eb2 chore(translations): add Brazilian Portuguese 2025-03-29 23:03:18 +01:00
Elias Schneider
8ca7873802 chore(translations): update translations via Crowdin (#394)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 22:59:24 +01:00
Elias Schneider
591bf841f5 Merge remote-tracking branch 'origin/main' 2025-03-29 22:56:04 +01:00
Kyle Mendell
8f8884d208 refactor: add swagger title and version info (#399) 2025-03-29 21:55:47 +00:00
Elias Schneider
7e658276f0 fix: ldap users aren't deleted if removed from ldap server 2025-03-29 22:55:44 +01:00
Gutyina Gergő
583a1f8fee chore(deps): install inlang plugins from npm (#401)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-29 22:50:51 +01:00
Rich
b935a4824a ci/cd: migrate backend linter to v2. fixed unit test workflow (#400) 2025-03-28 04:00:55 -05:00
Elias Schneider
cbd1bbdf74 fix: use value receiver for AuditLogData 2025-03-27 22:41:19 +01:00
Alessandro (Ale) Segala
96876a99c5 feat: add support for ECDSA and EdDSA keys (#359)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-27 18:20:39 +01:00
Elias Schneider
5c198c280c refactor: fix code smells 2025-03-27 17:46:10 +01:00
Elias Schneider
c9e0073b63 refactor: fix code smells 2025-03-27 16:48:36 +01:00
Elias Schneider
6fa26c97be ci/cd: run linter only on backend changes 2025-03-27 16:18:15 +01:00
Elias Schneider
6746dbf41e chore(translations): update translations via Crowdin (#386) 2025-03-27 15:15:22 +00:00
Rich
4ac1196d8d ci/cd: add basic static analysis for backend (#389) 2025-03-27 16:13:56 +01:00
Sam
4d049bbe24 docs: update .env.example to reflect the new documentation location (#385) 2025-03-25 21:53:23 +00:00
Elias Schneider
664a1cf8ef release: 0.44.0 2025-03-25 17:09:06 +01:00
Elias Schneider
e6f50191cf fix: stop container if Caddy, the frontend or the backend fails 2025-03-25 16:40:53 +01:00
dependabot[bot]
de9a3cce03 chore(deps-dev): bump vite from 6.2.1 to 6.2.3 in /frontend in the npm_and_yarn group across 1 directory (#384)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-25 09:52:15 -05:00
Alessandro (Ale) Segala
8c963818bb fix: hash the refresh token in the DB (security) (#379) 2025-03-25 15:36:53 +01:00
Alessandro (Ale) Segala
26b2de4f00 refactor: use atomic renames for uploaded files (#372)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 20:21:44 +00:00
Kyle Mendell
b8dcda8049 feat: add OIDC refresh_token support (#325)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 20:14:26 +00:00
Kyle Mendell
7888d70656 docs: fix api routers for swag documentation (#378) 2025-03-23 19:26:07 +00:00
Elias Schneider
35766af055 chore(translations): add French, Czech and German to language picker 2025-03-23 20:13:58 +01:00
Elias Schneider
c53de25d25 chore(translations): update translations via Crowdin (#375) 2025-03-23 19:09:34 +00:00
Kyle Mendell
cdfe8161d4 fix: skip ldap objects without a valid unique id (#376)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-23 18:30:12 +00:00
dependabot[bot]
e2f74e5687 chore(deps): bump github.com/golang-jwt/jwt/v5 from 5.2.1 to 5.2.2 in /backend in the go_modules group across 1 directory (#374)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-21 17:45:52 -05:00
Elias Schneider
132efd675c chore(translations): update translations via Crowdin (#368) 2025-03-21 21:32:28 +00:00
Elias Schneider
1167454c4f Merge branch 'main' of https://github.com/pocket-id/pocket-id 2025-03-21 22:30:40 +01:00
Elias Schneider
af5b2f7913 ci/cd: skip e2e tests if the PR comes from i18n_crowdin 2025-03-21 22:30:37 +01:00
Savely Krasovsky
bc4af846e1 chore(translations): add Russian localization (#371)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-21 21:24:55 +00:00
Elias Schneider
edf1097dd3 ci/cd: fix invalid action configuration 2025-03-21 22:20:05 +01:00
Elias Schneider
eb34535c5a release: 0.43.1 2025-03-20 21:38:02 +01:00
Elias Schneider
3120ebf239 fix: wrong base locale causes crash 2025-03-20 21:36:05 +01:00
Elias Schneider
2fb41937ca ci/cd: ignore e2e tests on Crowdin branch 2025-03-20 20:49:17 +01:00
Elias Schneider
d78a1c6974 release: 0.43.0 2025-03-20 20:47:17 +01:00
Elias Schneider
c578baba95 chore: add language request issue template 2025-03-20 20:38:33 +01:00
Elias Schneider
bb23194e88 chore(translations): remove unused messages 2025-03-20 20:26:43 +01:00
Elias Schneider
31ac56004a refactor: use language code with country for messages 2025-03-20 20:15:26 +01:00
Elias Schneider
d59ec01b33 Update Crowdin configuration file 2025-03-20 20:12:48 +01:00
Elias Schneider
3ee26a2cfb chore: update Crowdin configuration 2025-03-20 20:09:05 +01:00
Elias Schneider
39395c79c3 Update Crowdin configuration file 2025-03-20 20:08:24 +01:00
Jonas Claes
269b5a3c92 feat: add support for translations (#349)
Co-authored-by: Kyle Mendell <kmendell@outlook.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-20 18:57:41 +00:00
Kyle Mendell
041c565dc1 feat(passkeys): name new passkeys based on agguids (#332)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-20 15:35:08 +00:00
Elias Schneider
e486dbd771 release: 0.42.1 2025-03-18 23:03:50 +01:00
Elias Schneider
f7e36a422e fix: kid not added to JWTs 2025-03-18 23:03:34 +01:00
Elias Schneider
f74c7bf95d release: 0.42.0 2025-03-18 21:11:19 +01:00
Alessandro (Ale) Segala
a7c9741802 feat: store keys as JWK on disk (#339)
Co-authored-by: Kyle Mendell <kmendell@ofkm.us>
2025-03-18 21:08:33 +01:00
Elias Schneider
e9b2d981b7 release: 0.41.0 2025-03-18 21:04:53 +01:00
Kyle Mendell
8f146188d5 feat(profile-picture): allow reset of profile picture (#355)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-18 19:59:31 +00:00
Viktor Szépe
a0f93bda49 chor: correct misspellings (#352) 2025-03-18 12:54:39 +01:00
Savely Krasovsky
0423d354f5 fix: own avatar not loading (#351) 2025-03-18 12:02:59 +01:00
Elias Schneider
9245851126 release: 0.40.1 2025-03-16 18:02:49 +01:00
Alexander Lehmann
39b7f6678c fix: emails are considered as medium spam by rspamd (#337) 2025-03-16 17:46:45 +01:00
Elias Schneider
e45d9e970d fix: caching for own profile picture 2025-03-16 17:45:30 +01:00
Elias Schneider
8ead0be8cd fix: API keys not working if sqlite is used 2025-03-16 14:28:44 +01:00
Elias Schneider
9f28503d6c fix: remove custom claim key restrictions 2025-03-16 14:11:33 +01:00
Elias Schneider
26e05947fe ci/cd: add separate worfklow for unit tests 2025-03-16 13:08:56 +01:00
Alessandro (Ale) Segala
348192b9d7 fix: Fixes and performance improvements in utils package (#331) 2025-03-14 19:21:24 -05:00
Kyle Mendell
b483e2e92f fix: email logo icon displaying too big (#336) 2025-03-14 13:38:27 -05:00
Elias Schneider
42f55e6e54 release: 0.40.0 2025-03-13 20:49:48 +01:00
Elias Schneider
a4bfd08a0f chore: automatically detect release type in release script 2025-03-13 20:49:33 +01:00
Alessandro (Ale) Segala
7b654c6bd1 feat: allow setting path where keys are stored (#327) 2025-03-13 17:01:15 +01:00
Elias Schneider
8c1c04db1d Merge branch 'main' of https://github.com/pocket-id/pocket-id 2025-03-13 14:18:54 +01:00
Elias Schneider
ec4b41a1d2 fix(docker): missing write permissions on scripts 2025-03-13 14:18:48 +01:00
dependabot[bot]
d27a121985 chore(deps): bump @babel/runtime from 7.26.7 to 7.26.10 in /frontend in the npm_and_yarn group across 1 directory (#328)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-13 14:15:49 +01:00
dependabot[bot]
d8952c0d62 chore(deps): bump golang.org/x/net from 0.34.0 to 0.36.0 in /backend in the go_modules group across 1 directory (#326)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-13 14:06:43 +01:00
Nebula
f65997e85b chore: add Dev Container (#313) 2025-03-11 17:24:41 -05:00
Elias Schneider
90f8068053 release: 0.39.0 2025-03-11 20:59:15 +01:00
Elias Schneider
9ef2ddf796 fix: alternative login method link on mobile 2025-03-11 20:58:30 +01:00
Elias Schneider
d1b9f3a44e refactor: adapt api key list to new sort behavior 2025-03-11 20:22:56 +01:00
Kyle Mendell
62915d863a feat: api key authentication (#291)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-11 19:16:42 +00:00
Elias Schneider
74ba8390f4 release: 0.38.0 2025-03-10 20:52:35 +01:00
Elias Schneider
31198feec2 feat: add env variable to disable update check 2025-03-10 20:48:57 +01:00
Elias Schneider
e5ec264bfd fix: redirection not correctly if signing in with email code 2025-03-10 20:36:52 +01:00
Kot C
c822192124 fix: typo in account settings (#307) 2025-03-10 13:35:46 +00:00
Elias Schneider
f2d61e964c release: 0.37.0 2025-03-10 14:09:30 +01:00
dependabot[bot]
f1256322b6 chore(deps): bump the npm_and_yarn group across 1 directory with 3 updates (#306)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-10 14:06:13 +01:00
Elias Schneider
7885ae011c tests: fix user group assignment test 2025-03-10 14:05:51 +01:00
Elias Schneider
6a8dd84ca9 fix: add back setup page 2025-03-10 13:00:08 +01:00
Jonas
eb1426ed26 feat(account): add ability to sign in with login code (#271)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-10 12:45:45 +01:00
Elias Schneider
a9713cf6a1 feat: increase default item count per page 2025-03-10 12:39:42 +01:00
Elias Schneider
8e344f1151 fix: make sorting consistent around tables 2025-03-10 12:37:16 +01:00
Elias Schneider
04efc36115 fix: add timeout to update check 2025-03-10 09:41:58 +01:00
Elias Schneider
2ee0bad2c0 docs: add Discord contact link to issue template 2025-03-07 14:25:19 +01:00
Elias Schneider
d0da532240 refactor: fix type errors 2025-03-07 13:56:24 +01:00
Elias Schneider
8d55c7c393 release: 0.36.0 2025-03-06 22:25:25 +01:00
Kyle Mendell
0f14a93e1d feat: display groups on the account page (#296)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-06 22:25:03 +01:00
Elias Schneider
37b24bed91 ci/cd: remove PR docker build action 2025-03-06 22:24:00 +01:00
Elias Schneider
66090f36a8 ci/cd: use github.repository variable intead of hardcoding the repository name 2025-03-06 19:13:44 +01:00
Kyle Mendell
ff34e3b925 fix: default sorting on tables (#299)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-06 17:42:31 +01:00
Savely Krasovsky
91f254c7bb feat: enable sd_notify support (#277) 2025-03-06 17:42:12 +01:00
Kyle Mendell
85db96b0ef ci/cd: add pr docker build (#293)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-03-06 16:29:33 +01:00
Elias Schneider
12d60fea23 release: 0.35.6 2025-03-03 16:49:55 +01:00
Elias Schneider
2d733fc79f fix: support LOGIN authentication method for SMTP (#292) 2025-03-03 16:48:38 +01:00
Elias Schneider
a421d01e0c release: 0.35.5 2025-03-03 16:48:07 +01:00
Elias Schneider
1026ee4f5b fix: profile picture orientation if image is rotated with EXIF 2025-03-03 09:06:52 +01:00
Elias Schneider
cddfe8fa4c release: 0.35.4 2025-03-01 20:42:53 +01:00
Jonas
ef25f6b6b8 fix: profile picture of other user can't be updated (#273) 2025-03-01 20:42:29 +01:00
Elias Schneider
1652cc65f3 fix: support POST for OIDC userinfo endpoint 2025-03-01 20:42:00 +01:00
Elias Schneider
4bafee4f58 fix: add groups scope and claim to well known endpoint 2025-03-01 20:41:30 +01:00
Elias Schneider
e46471cc2d release: 0.35.3 2025-02-25 20:34:37 +01:00
Elias Schneider
fde951b543 fix(ldap): sync error if LDAP user collides with an existing user 2025-02-25 20:34:13 +01:00
Kyle Mendell
01a9de0b04 fix: add option to manually select SMTP TLS method (#268)
Co-authored-by: Elias Schneider <login@eliasschneider.com>
2025-02-25 19:10:20 +01:00
Elias Schneider
a1131bca9a release: 0.35.2 2025-02-24 09:40:48 +01:00
Elias Schneider
9a167d4076 fix: delete profile picture if user gets deleted 2025-02-24 09:40:14 +01:00
Elias Schneider
887c5e462a fix: updating profile picture of other user updates own profile picture 2025-02-24 09:35:44 +01:00
Elias Schneider
20eba1378e release: 0.35.1 2025-02-22 14:59:43 +01:00
Elias Schneider
a6ae7ae287 fix: add validation that PUBLIC_APP_URL can't contain a path 2025-02-22 14:59:10 +01:00
Elias Schneider
840a672fc3 fix: binary profile picture can't be imported from LDAP 2025-02-22 14:51:21 +01:00
Elias Schneider
7446f853fc release: 0.35.0 2025-02-19 14:29:24 +01:00
Elias Schneider
652ee6ad5d feat: add ability to upload a profile picture (#244) 2025-02-19 14:28:45 +01:00
Elias Schneider
dca9e7a11a fix: emails do not get rendered correctly in Gmail 2025-02-19 13:54:36 +01:00
Elias Schneider
816c198a42 fix: app config strings starting with a number are parsed incorrectly 2025-02-18 21:36:08 +01:00
254 changed files with 11878 additions and 1948 deletions

View File

@@ -0,0 +1,32 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/typescript-node
{
"name": "pocket-id",
// Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile
"image": "mcr.microsoft.com/devcontainers/typescript-node:1-22-bookworm",
"features": {
"ghcr.io/devcontainers/features/go:1": {},
"ghcr.io/devcontainers-extra/features/caddy:1": {}
},
"customizations": {
"vscode": {
"extensions": [
"golang.go",
"svelte.svelte-vscode"
]
}
},
// Use 'postCreateCommand' to run commands after the container is created.
// Install npm dependencies for the frontend.
"postCreateCommand": "npm install --prefix frontend"
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "root"
}

View File

@@ -1,4 +1,4 @@
# See the README for more information: https://github.com/pocket-id/pocket-id?tab=readme-ov-file#environment-variables
# See the documentation for more information: https://pocket-id.org/docs/configuration/environment-variables
PUBLIC_APP_URL=http://localhost
TRUST_PROXY=false
MAXMIND_LICENSE_KEY=

View File

@@ -49,7 +49,7 @@ body:
required: false
attributes:
label: "Log Output"
description: "Output of log files when the issue occured to help us diagnose the issue."
description: "Output of log files when the issue occurred to help us diagnose the issue."
- type: markdown
attributes:
value: |

View File

@@ -1 +1,5 @@
blank_issues_enabled: false
blank_issues_enabled: false
contact_links:
- name: 💬 Discord
url: https://discord.gg/8wudU9KaxM
about: For help and chatting with the community

View File

@@ -0,0 +1,20 @@
name: "🌐 Language request"
description: "You want to contribute to a language that isn't on Crowdin yet?"
title: "🌐 Language Request: <language name in english>"
labels: [language-request]
body:
- type: input
id: language-name-native
attributes:
label: "🌐 Language Name (native)"
placeholder: "Schweizerdeutsch"
validations:
required: true
- type: input
id: language-code
attributes:
label: "🌐 ISO 639-1 Language Code"
description: "You can find your language code [here](https://www.andiamo.co.uk/resources/iso-language-codes/)."
placeholder: "de-CH"
validations:
required: true

12
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for more information:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
# https://containers.dev/guide/dependabot
version: 2
updates:
- package-ecosystem: "devcontainers"
directory: "/"
schedule:
interval: weekly

39
.github/workflows/backend-linter.yml vendored Normal file
View File

@@ -0,0 +1,39 @@
name: Run Backend Linter
on:
push:
branches: [main]
paths:
- "backend/**"
pull_request:
branches: [main]
paths:
- "backend/**"
permissions:
# Required: allow read access to the content for analysis.
contents: read
# Optional: allow read access to pull request. Use with `only-new-issues` option.
pull-requests: read
# Optional: allow write access to checks to allow the action to annotate code in the PR.
checks: write
jobs:
golangci-lint:
name: Run Golangci-lint
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version-file: backend/go.mod
- name: Run Golangci-lint
uses: golangci/golangci-lint-action@dec74fa03096ff515422f71d18d41307cacde373 # v7.0.0
with:
version: v2.0.2
working-directory: backend
only-new-issues: ${{ github.event_name == 'pull_request' }}

View File

@@ -30,11 +30,6 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_REGISTRY_USERNAME }}
password: ${{ secrets.DOCKER_REGISTRY_PASSWORD }}
- name: 'Login to GitHub Container Registry'
uses: docker/login-action@v3

View File

@@ -15,6 +15,7 @@ on:
jobs:
build:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
timeout-minutes: 20
runs-on: ubuntu-latest
steps:
@@ -26,6 +27,7 @@ jobs:
with:
tags: pocket-id/pocket-id:test
outputs: type=docker,dest=/tmp/docker-image.tar
build-args: BUILD_TAGS=e2etest
- name: Upload Docker image artifact
uses: actions/upload-artifact@v4
@@ -34,6 +36,7 @@ jobs:
path: /tmp/docker-image.tar
test-sqlite:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
runs-on: ubuntu-latest
needs: build
steps:
@@ -49,6 +52,7 @@ jobs:
with:
name: docker-image
path: /tmp
- name: Load Docker Image
run: docker load -i /tmp/docker-image.tar
@@ -67,6 +71,8 @@ jobs:
-e APP_ENV=test \
pocket-id/pocket-id:test
docker logs -f pocket-id-sqlite &> /tmp/backend.log &
- name: Run Playwright tests
working-directory: ./frontend
run: npx playwright test
@@ -79,7 +85,16 @@ jobs:
include-hidden-files: true
retention-days: 15
- uses: actions/upload-artifact@v4
if: always()
with:
name: backend-sqlite
path: /tmp/backend.log
include-hidden-files: true
retention-days: 15
test-postgres:
if: github.event.pull_request.head.ref != 'i18n_crowdin'
runs-on: ubuntu-latest
needs: build
steps:
@@ -137,17 +152,27 @@ jobs:
-p 80:80 \
-e APP_ENV=test \
-e DB_PROVIDER=postgres \
-e POSTGRES_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
-e DB_CONNECTION_STRING=postgresql://postgres:postgres@pocket-id-db:5432/pocket-id \
pocket-id/pocket-id:test
docker logs -f pocket-id-postgres &> /tmp/backend.log &
- name: Run Playwright tests
working-directory: ./frontend
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
if: always() && github.event.pull_request.head.ref != 'i18n_crowdin'
with:
name: playwright-report-postgres
path: frontend/tests/.report
include-hidden-files: true
retention-days: 15
- uses: actions/upload-artifact@v4
if: always()
with:
name: backend-postgres
path: /tmp/backend.log
include-hidden-files: true
retention-days: 15

35
.github/workflows/unit-tests.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Unit Tests
on:
push:
branches: [main]
paths:
- "backend/**"
pull_request:
branches: [main]
paths:
- "backend/**"
jobs:
test-backend:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: 'backend/go.mod'
cache-dependency-path: 'backend/go.sum'
- name: Install dependencies
working-directory: backend
run: |
go get ./...
- name: Run backend unit tests
working-directory: backend
run: |
set -e -o pipefail
go test -v ./... | tee /tmp/TestResults.log
- uses: actions/upload-artifact@v4
if: always()
with:
name: backend-unit-tests
path: /tmp/TestResults.log
retention-days: 15

34
.github/workflows/update-aaguids.yml vendored Normal file
View File

@@ -0,0 +1,34 @@
name: Update AAGUIDs
on:
schedule:
- cron: "0 0 * * 1" # Runs every Monday at midnight
workflow_dispatch: # Allows manual triggering of the workflow
permissions:
contents: write
jobs:
update-aaguids:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Fetch JSON data
run: |
curl -o data.json https://raw.githubusercontent.com/pocket-id/passkey-aaguids/refs/heads/main/combined_aaguid.json
- name: Process JSON data
run: |
mkdir -p backend/resources
jq -c 'map_values(.name)' data.json > backend/resources/aaguids.json
- name: Commit changes
run: |
git config --local user.email "action@github.com"
git config --local user.name "GitHub Action"
git add backend/resources/aaguids.json
git diff --staged --quiet || git commit -m "chore: update AAGUIDs"
git push

3
.gitignore vendored
View File

@@ -48,3 +48,6 @@ pocket-id-backend
npm-debug.log*
yarn-debug.log*
yarn-error.log*
#Debug
backend/cmd/__debug_*

View File

@@ -1 +1 @@
0.34.0
0.45.0

5
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,5 @@
{
"recommendations": [
"inlang.vs-code-extension"
]
}

42
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,42 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Backend",
"type": "go",
"request": "launch",
"envFile": "${workspaceFolder}/backend/.env.example",
"env": {
"APP_ENV": "development"
},
"mode": "debug",
"program": "${workspaceFolder}/backend/cmd/main.go",
},
{
"name": "Frontend",
"type": "node",
"request": "launch",
"envFile": "${workspaceFolder}/frontend/.env.example",
"cwd": "${workspaceFolder}/frontend",
"runtimeExecutable": "npm",
"runtimeArgs": [
"run",
"dev"
]
}
],
"compounds": [
{
"name": "Development",
"configurations": [
"Backend",
"Frontend"
],
"presentation": {
"hidden": false,
"group": "",
"order": 1
}
}
],
}

3
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"go.buildTags": "e2etest"
}

37
.vscode/tasks.json vendored Normal file
View File

@@ -0,0 +1,37 @@
{
// See https://go.microsoft.com/fwlink/?LinkId=733558
// for the documentation about the tasks.json format
"version": "2.0.0",
"tasks": [
{
"label": "Run Caddy",
"type": "shell",
"command": "caddy run --config reverse-proxy/Caddyfile",
"isBackground": true,
"problemMatcher": {
"owner": "custom",
"pattern": [
{
"regexp": ".",
"file": 1,
"location": 2,
"message": 3
}
],
"background": {
"activeOnStart": true,
"beginsPattern": ".*",
"endsPattern": "Caddyfile.*"
}
},
"presentation": {
"reveal": "always",
"panel": "new"
},
"runOptions": {
"runOn": "folderOpen",
"instanceLimit": 1
}
}
]
}

View File

@@ -1,3 +1,209 @@
## [](https://github.com/pocket-id/pocket-id/compare/v0.44.0...v) (2025-03-29)
### Features
* add support for ECDSA and EdDSA keys ([#359](https://github.com/pocket-id/pocket-id/issues/359)) ([96876a9](https://github.com/pocket-id/pocket-id/commit/96876a99c586508b72c27669ab200ff6a29db771))
### Bug Fixes
* ldap users aren't deleted if removed from ldap server ([7e65827](https://github.com/pocket-id/pocket-id/commit/7e658276f04d08a1f5117796e55d45e310204dab))
* use value receiver for `AuditLogData` ([cbd1bbd](https://github.com/pocket-id/pocket-id/commit/cbd1bbdf741eedd03e93598d67623c75c74b6212))
* use WAL for SQLite by default and set busy_timeout ([#388](https://github.com/pocket-id/pocket-id/issues/388)) ([519d58d](https://github.com/pocket-id/pocket-id/commit/519d58d88c906abc5139e35933bdeba0396c10a2))
## [](https://github.com/pocket-id/pocket-id/compare/v0.43.1...v) (2025-03-25)
### Features
* add OIDC refresh_token support ([#325](https://github.com/pocket-id/pocket-id/issues/325)) ([b8dcda8](https://github.com/pocket-id/pocket-id/commit/b8dcda80497e554d163a370eff81fe000f8831f4))
### Bug Fixes
* hash the refresh token in the DB (security) ([#379](https://github.com/pocket-id/pocket-id/issues/379)) ([8c96381](https://github.com/pocket-id/pocket-id/commit/8c963818bb90c84dac04018eec93790900d4b0ce))
* skip ldap objects without a valid unique id ([#376](https://github.com/pocket-id/pocket-id/issues/376)) ([cdfe816](https://github.com/pocket-id/pocket-id/commit/cdfe8161d4429bdfe879887fe0b563a67c14f50b))
* stop container if Caddy, the frontend or the backend fails ([e6f5019](https://github.com/pocket-id/pocket-id/commit/e6f50191cf05a5d0ac0e0000cf66423646f1920e))
## [](https://github.com/pocket-id/pocket-id/compare/v0.43.0...v) (2025-03-20)
### Bug Fixes
* wrong base locale causes crash ([3120ebf](https://github.com/pocket-id/pocket-id/commit/3120ebf239b90f0bc0a0af33f30622e034782398))
## [](https://github.com/pocket-id/pocket-id/compare/v0.42.1...v) (2025-03-20)
### Features
* add support for translations ([#349](https://github.com/pocket-id/pocket-id/issues/349)) ([269b5a3](https://github.com/pocket-id/pocket-id/commit/269b5a3c9249bb8081c74741141d3d5a69ea42a2))
* **passkeys:** name new passkeys based on agguids ([#332](https://github.com/pocket-id/pocket-id/issues/332)) ([041c565](https://github.com/pocket-id/pocket-id/commit/041c565dc10f15edb3e8ab58e9a4df5e48a2a6d3))
## [](https://github.com/pocket-id/pocket-id/compare/v0.42.0...v) (2025-03-18)
### Bug Fixes
* kid not added to JWTs ([f7e36a4](https://github.com/pocket-id/pocket-id/commit/f7e36a422ea6b5327360c9a13308ae408ff7fffe))
## [](https://github.com/pocket-id/pocket-id/compare/v0.41.0...v) (2025-03-18)
### Features
* store keys as JWK on disk ([#339](https://github.com/pocket-id/pocket-id/issues/339)) ([a7c9741](https://github.com/pocket-id/pocket-id/commit/a7c9741802667811c530ef4e6313b71615ec6a9b))
## [](https://github.com/pocket-id/pocket-id/compare/v0.40.1...v) (2025-03-18)
### Features
* **profile-picture:** allow reset of profile picture ([#355](https://github.com/pocket-id/pocket-id/issues/355)) ([8f14618](https://github.com/pocket-id/pocket-id/commit/8f146188d57b5c08a4c6204674c15379232280d8))
### Bug Fixes
* own avatar not loading ([#351](https://github.com/pocket-id/pocket-id/issues/351)) ([0423d35](https://github.com/pocket-id/pocket-id/commit/0423d354f533d2ff4fd431859af3eea7d4d7044f))
## [](https://github.com/pocket-id/pocket-id/compare/v0.40.0...v) (2025-03-16)
### Bug Fixes
* API keys not working if sqlite is used ([8ead0be](https://github.com/pocket-id/pocket-id/commit/8ead0be8cd0cfb542fe488b7251cfd5274975ae1))
* caching for own profile picture ([e45d9e9](https://github.com/pocket-id/pocket-id/commit/e45d9e970d327a5120ff9fb0c8d42df8af69bb38))
* email logo icon displaying too big ([#336](https://github.com/pocket-id/pocket-id/issues/336)) ([b483e2e](https://github.com/pocket-id/pocket-id/commit/b483e2e92fdb528e7de026350a727d6970227426))
* emails are considered as medium spam by rspamd ([#337](https://github.com/pocket-id/pocket-id/issues/337)) ([39b7f66](https://github.com/pocket-id/pocket-id/commit/39b7f6678c98cadcdc3abfbcb447d8eb0daa9eb0))
* Fixes and performance improvements in utils package ([#331](https://github.com/pocket-id/pocket-id/issues/331)) ([348192b](https://github.com/pocket-id/pocket-id/commit/348192b9d7e2698add97810f8fba53d13d0df018))
* remove custom claim key restrictions ([9f28503](https://github.com/pocket-id/pocket-id/commit/9f28503d6c73d3521d1309bee055704a0507e9b5))
## [](https://github.com/pocket-id/pocket-id/compare/v0.39.0...v) (2025-03-13)
### Features
* allow setting path where keys are stored ([#327](https://github.com/pocket-id/pocket-id/issues/327)) ([7b654c6](https://github.com/pocket-id/pocket-id/commit/7b654c6bd111ddcddd5e3450cbf326d9cf1777b6))
### Bug Fixes
* **docker:** missing write permissions on scripts ([ec4b41a](https://github.com/pocket-id/pocket-id/commit/ec4b41a1d26ea00bb4a95f654ac4cc745b2ce2e8))
## [](https://github.com/pocket-id/pocket-id/compare/v0.38.0...v) (2025-03-11)
### Features
* api key authentication ([#291](https://github.com/pocket-id/pocket-id/issues/291)) ([62915d8](https://github.com/pocket-id/pocket-id/commit/62915d863a4adc09cf467b75c414a045be43c2bb))
### Bug Fixes
* alternative login method link on mobile ([9ef2ddf](https://github.com/pocket-id/pocket-id/commit/9ef2ddf7963c6959992f3a5d6816840534e926e9))
## [](https://github.com/pocket-id/pocket-id/compare/v0.37.0...v) (2025-03-10)
### Features
* add env variable to disable update check ([31198fe](https://github.com/pocket-id/pocket-id/commit/31198feec2ae77dd6673c42b42002871ddd02d37))
### Bug Fixes
* redirection not correctly if signing in with email code ([e5ec264](https://github.com/pocket-id/pocket-id/commit/e5ec264bfd535752565bcc107099a9df5cb8aba7))
* typo in account settings ([#307](https://github.com/pocket-id/pocket-id/issues/307)) ([c822192](https://github.com/pocket-id/pocket-id/commit/c8221921245deb3008f655740d1a9460dcdab2fc))
## [](https://github.com/pocket-id/pocket-id/compare/v0.36.0...v) (2025-03-10)
### Features
* **account:** add ability to sign in with login code ([#271](https://github.com/pocket-id/pocket-id/issues/271)) ([eb1426e](https://github.com/pocket-id/pocket-id/commit/eb1426ed2684b5ddd185db247a8e082b28dfd014))
* increase default item count per page ([a9713cf](https://github.com/pocket-id/pocket-id/commit/a9713cf6a1e3c879dc773889b7983e51bbe3c45b))
### Bug Fixes
* add back setup page ([6a8dd84](https://github.com/pocket-id/pocket-id/commit/6a8dd84ca9396ff3369385af22f7e1f081bec2b2))
* add timeout to update check ([04efc36](https://github.com/pocket-id/pocket-id/commit/04efc3611568a0b0127b542b8cc252d9e783af46))
* make sorting consistent around tables ([8e344f1](https://github.com/pocket-id/pocket-id/commit/8e344f1151628581b637692a1de0e48e7235a22d))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.6...v) (2025-03-06)
### Features
* display groups on the account page ([#296](https://github.com/pocket-id/pocket-id/issues/296)) ([0f14a93](https://github.com/pocket-id/pocket-id/commit/0f14a93e1d6a723b0994ba475b04702646f04464))
* enable sd_notify support ([#277](https://github.com/pocket-id/pocket-id/issues/277)) ([91f254c](https://github.com/pocket-id/pocket-id/commit/91f254c7bb067646c42424c5c62ebcd90a0c8792))
### Bug Fixes
* default sorting on tables ([#299](https://github.com/pocket-id/pocket-id/issues/299)) ([ff34e3b](https://github.com/pocket-id/pocket-id/commit/ff34e3b925321c80e9d7d42d0fd50e397d198435))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.5...v) (2025-03-03)
### Bug Fixes
* support `LOGIN` authentication method for SMTP ([#292](https://github.com/pocket-id/pocket-id/issues/292)) ([2d733fc](https://github.com/pocket-id/pocket-id/commit/2d733fc79faefca23d54b22768029c3ba3427410))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.4...v) (2025-03-03)
### Bug Fixes
* profile picture orientation if image is rotated with EXIF ([1026ee4](https://github.com/pocket-id/pocket-id/commit/1026ee4f5b5c7fda78b65c94a5d0f899525defd1))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.3...v) (2025-03-01)
### Bug Fixes
* add `groups` scope and claim to well known endpoint ([4bafee4](https://github.com/pocket-id/pocket-id/commit/4bafee4f58f5a76898cf66d6192916d405eea389))
* profile picture of other user can't be updated ([#273](https://github.com/pocket-id/pocket-id/issues/273)) ([ef25f6b](https://github.com/pocket-id/pocket-id/commit/ef25f6b6b84b52f1310d366d40aa3769a6fe9bef))
* support POST for OIDC userinfo endpoint ([1652cc6](https://github.com/pocket-id/pocket-id/commit/1652cc65f3f966d018d81a1ae22abb5ff1b4c47b))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.2...v) (2025-02-25)
### Bug Fixes
* add option to manually select SMTP TLS method ([#268](https://github.com/pocket-id/pocket-id/issues/268)) ([01a9de0](https://github.com/pocket-id/pocket-id/commit/01a9de0b04512c62d0f223de33d711f93c49b9cc))
* **ldap:** sync error if LDAP user collides with an existing user ([fde951b](https://github.com/pocket-id/pocket-id/commit/fde951b543281fedf9f602abae26b50881e3d157))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.1...v) (2025-02-24)
### Bug Fixes
* delete profile picture if user gets deleted ([9a167d4](https://github.com/pocket-id/pocket-id/commit/9a167d4076872e5e3e5d78d2a66ef7203ca5261b))
* updating profile picture of other user updates own profile picture ([887c5e4](https://github.com/pocket-id/pocket-id/commit/887c5e462a50c8fb579ca6804f1a643d8af78fe8))
## [](https://github.com/pocket-id/pocket-id/compare/v0.35.0...v) (2025-02-22)
### Bug Fixes
* add validation that `PUBLIC_APP_URL` can't contain a path ([a6ae7ae](https://github.com/pocket-id/pocket-id/commit/a6ae7ae28713f7fc8018ae2aa7572986df3e1a5b))
* binary profile picture can't be imported from LDAP ([840a672](https://github.com/pocket-id/pocket-id/commit/840a672fc35ca8476caf86d7efaba9d54bce86aa))
## [](https://github.com/pocket-id/pocket-id/compare/v0.34.0...v) (2025-02-19)
### Features
* add ability to upload a profile picture ([#244](https://github.com/pocket-id/pocket-id/issues/244)) ([652ee6a](https://github.com/pocket-id/pocket-id/commit/652ee6ad5d6c46f0d35c955ff7bb9bdac6240ca6))
### Bug Fixes
* app config strings starting with a number are parsed incorrectly ([816c198](https://github.com/pocket-id/pocket-id/commit/816c198a42c189cb1f2d94885d2e3623e47e2848))
* emails do not get rendered correctly in Gmail ([dca9e7a](https://github.com/pocket-id/pocket-id/commit/dca9e7a11a3ba5d3b43a937f11cb9d16abad2db5))
## [](https://github.com/pocket-id/pocket-id/compare/v0.33.0...v) (2025-02-16)

View File

@@ -31,8 +31,15 @@ Before you submit the pull request for review please ensure that
- You run `npm run format` to format the code
## Setup project
Pocket ID consists of a frontend, backend and a reverse proxy. There are two ways to get the development environment setup:
Pocket ID consists of a frontend, backend and a reverse proxy.
## 1. Using DevContainers
1. Make sure you have [Dev Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension installed
2. Clone and open the repo in VS Code
3. VS Code will detect .devcontainer and will prompt you to open the folder in devcontainer
4. If the auto prompt does not work, hit `F1` and select `Dev Containers: Open Folder in Container.`, then select the pocket-id repo root folder and it'll open in container.
## 2. Manual
### Backend
@@ -42,7 +49,7 @@ The backend is built with [Gin](https://gin-gonic.com) and written in Go.
1. Open the `backend` folder
2. Copy the `.env.example` file to `.env` and change the `APP_ENV` to `development`
3. Start the backend with `go run cmd/main.go`
3. Start the backend with `go run -tags e2etest ./cmd`
### Frontend
@@ -63,6 +70,10 @@ Run `caddy run --config reverse-proxy/Caddyfile` in the root folder.
You're all set!
## Debugging
1. The VS Code is currently setup to auto launch caddy on opening the folder. (Defined in [tasks.json](.vscode/tasks.json))
2. Press `F5` to start a debug session. This will launch both frontend and backend and attach debuggers to those process. (Defined in [launch.json](.vscode/launch.json))
### Testing
We are using [Playwright](https://playwright.dev) for end-to-end testing.

View File

@@ -1,3 +1,6 @@
# Tags passed to "go build"
ARG BUILD_TAGS=""
# Stage 1: Build Frontend
FROM node:22-alpine AS frontend-builder
WORKDIR /app/frontend
@@ -9,6 +12,7 @@ RUN npm prune --production
# Stage 2: Build Backend
FROM golang:1.23-alpine AS backend-builder
ARG BUILD_TAGS
WORKDIR /app/backend
COPY ./backend/go.mod ./backend/go.sum ./
RUN go mod download
@@ -17,7 +21,12 @@ RUN apk add --no-cache gcc musl-dev
COPY ./backend ./
WORKDIR /app/backend/cmd
RUN CGO_ENABLED=1 GOOS=linux go build -o /app/backend/pocket-id-backend .
RUN CGO_ENABLED=1 \
GOOS=linux \
go build \
-tags "${BUILD_TAGS}" \
-o /app/backend/pocket-id-backend \
.
# Stage 3: Production Image
FROM node:22-alpine
@@ -35,10 +44,10 @@ COPY --from=frontend-builder /app/frontend/package.json ./frontend/package.json
COPY --from=backend-builder /app/backend/pocket-id-backend ./backend/pocket-id-backend
COPY ./scripts ./scripts
RUN chmod +x ./scripts/*.sh
RUN chmod +x ./scripts/**/*.sh
EXPOSE 80
ENV APP_ENV=production
ENTRYPOINT ["sh", "./scripts/docker/create-user.sh"]
CMD ["sh", "./scripts/docker/entrypoint.sh"]
CMD ["sh", "./scripts/docker/entrypoint.sh"]

64
backend/.golangci.yml Normal file
View File

@@ -0,0 +1,64 @@
version: "2"
run:
tests: true
timeout: 5m
linters:
default: none
enable:
- asasalint
- asciicheck
- bidichk
- bodyclose
- contextcheck
- copyloopvar
- durationcheck
- errcheck
- errchkjson
- errorlint
- exhaustive
- gocheckcompilerdirectives
- gochecksumtype
- gocognit
- gocritic
- gosec
- gosmopolitan
- govet
- ineffassign
- loggercheck
- makezero
- musttag
- nilerr
- nilnesserr
- noctx
- protogetter
- reassign
- recvcheck
- rowserrcheck
- spancheck
- sqlclosecheck
- staticcheck
- testifylint
- unused
- usestdlibvars
- zerologlint
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- third_party$
- builtin$
- examples$
- internal/service/test_service.go
formatters:
enable:
- goimports
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -1,22 +1,28 @@
module github.com/pocket-id/pocket-id/backend
go 1.23.1
go 1.23.7
require (
github.com/caarlos0/env/v11 v11.3.1
github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec
github.com/disintegration/imaging v1.6.2
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
github.com/emersion/go-smtp v0.21.3
github.com/fxamacker/cbor/v2 v2.7.0
github.com/gin-gonic/gin v1.10.0
github.com/go-co-op/gocron/v2 v2.15.0
github.com/go-ldap/ldap/v3 v3.4.10
github.com/go-playground/validator/v10 v10.24.0
github.com/go-webauthn/webauthn v0.11.2
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1
github.com/mileusna/useragent v1.3.5
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.2
golang.org/x/crypto v0.32.0
github.com/stretchr/testify v1.10.0
golang.org/x/crypto v0.36.0
golang.org/x/image v0.24.0
golang.org/x/time v0.9.0
gorm.io/driver/postgres v1.5.11
gorm.io/driver/sqlite v1.5.7
@@ -28,6 +34,9 @@ require (
github.com/bytedance/sonic v1.12.8 // indirect
github.com/bytedance/sonic/loader v0.2.3 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect
github.com/disintegration/gift v1.1.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.8 // indirect
github.com/gin-contrib/sse v1.0.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.7 // indirect
@@ -35,6 +44,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-webauthn/x v0.1.16 // indirect
github.com/goccy/go-json v0.10.4 // indirect
github.com/golang-jwt/jwt/v5 v5.2.2 // indirect
github.com/google/go-tpm v0.9.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
@@ -49,6 +59,10 @@ require (
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.24 // indirect
@@ -56,17 +70,19 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/segmentio/asm v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/x448/float16 v0.8.4 // indirect
go.uber.org/atomic v1.11.0 // indirect
golang.org/x/arch v0.13.0 // indirect
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 // indirect
golang.org/x/net v0.34.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/net v0.36.0 // indirect
golang.org/x/sync v0.12.0 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.36.4 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -20,8 +20,16 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40=
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
github.com/disintegration/gift v1.1.2 h1:9ZyHJr+kPamiH10FX3Pynt1AxFUob812bU9Wt4GMzhs=
github.com/disintegration/gift v1.1.2/go.mod h1:Jh2i7f7Q2BM7Ezno3PhfezbR1xpUg9dUg3/RlKGr4HI=
github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec h1:YrB6aVr9touOt75I9O1SiancmR2GMg45U9UYf0gtgWg=
github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec/go.mod h1:K0KBFIr1gWu/C1Gp10nFAcAE4hsB7JxE6OgLijrJ8Sk=
github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c=
github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4=
@@ -30,6 +38,10 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 h1:OJyUGMJTzHTd1XQp98QTaHernxMYzRaOasRir9hUlFQ=
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
github.com/emersion/go-smtp v0.21.3 h1:7uVwagE8iPYE48WhNsng3RRpCUpFvNl39JGNSIyGVMY=
github.com/emersion/go-smtp v0.21.3/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E=
@@ -66,8 +78,8 @@ github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM=
github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
@@ -127,6 +139,16 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1 h1:pzDjP9dSONCFQC/AE3mWUnHILGiYPiMKzQIS+weKJXA=
github.com/lestrrat-go/httprc/v3 v3.0.0-beta1/go.mod h1:wdsgouffPvWPEYh8t7PRH/PidR5sfVqt0na4Nhj60Ms=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1 h1:Iqjb8JvWjh34Jv8DeM2wQ1aG5fzFBzwQu7rlqwuJB0I=
github.com/lestrrat-go/jwx/v3 v3.0.0-beta1/go.mod h1:ak32WoNtHE0aLowVWBcCvXngcAnW4tuC0YhFwOr/kwc=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
@@ -166,12 +188,15 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8=
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
@@ -207,10 +232,13 @@ golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliY
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA=
golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU=
golang.org/x/image v0.0.0-20191009234506-e7c1f5e7dbb8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ=
golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
@@ -227,16 +255,17 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA=
golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -249,8 +278,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -268,8 +297,9 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -8,8 +8,12 @@ import (
func Bootstrap() {
initApplicationImages()
migrateConfigDBConnstring()
db := newDatabase()
appConfigService := service.NewAppConfigService(db)
migrateKey()
initRouter(db, appConfigService)
}

View File

@@ -0,0 +1,34 @@
package bootstrap
import (
"log"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
// Performs the migration of the database connection string
// See: https://github.com/pocket-id/pocket-id/pull/388
func migrateConfigDBConnstring() {
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
// Check if we're using the deprecated SqliteDBPath env var
if common.EnvConfig.SqliteDBPath != "" {
connString := "file:" + common.EnvConfig.SqliteDBPath + "?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate"
common.EnvConfig.DbConnectionString = connString
common.EnvConfig.SqliteDBPath = ""
log.Printf("[WARN] Env var 'SQLITE_DB_PATH' is deprecated - use 'DB_CONNECTION_STRING' instead with the value: '%s'", connString)
}
case common.DbProviderPostgres:
// Check if we're using the deprecated PostgresConnectionString alias
if common.EnvConfig.PostgresConnectionString != "" {
common.EnvConfig.DbConnectionString = common.EnvConfig.PostgresConnectionString
common.EnvConfig.PostgresConnectionString = ""
log.Print("[WARN] Env var 'POSTGRES_CONNECTION_STRING' is deprecated - use 'DB_CONNECTION_STRING' instead with the same value")
}
default:
// We don't do anything here in the default case
// This is an error, but will be handled later on
}
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"log"
"os"
"strings"
"time"
"github.com/golang-migrate/migrate/v4"
@@ -38,6 +39,7 @@ func newDatabase() (db *gorm.DB) {
case common.DbProviderPostgres:
driver, err = postgresMigrate.WithInstance(sqlDb, &postgresMigrate.Config{})
default:
// Should never happen at this point
log.Fatalf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
if err != nil {
@@ -56,17 +58,17 @@ func migrateDatabase(driver database.Driver) error {
// Use the embedded migrations
source, err := iofs.New(resources.FS, "migrations/"+string(common.EnvConfig.DbProvider))
if err != nil {
return fmt.Errorf("failed to create embedded migration source: %v", err)
return fmt.Errorf("failed to create embedded migration source: %w", err)
}
m, err := migrate.NewWithInstance("iofs", source, "pocket-id", driver)
if err != nil {
return fmt.Errorf("failed to create migration instance: %v", err)
return fmt.Errorf("failed to create migration instance: %w", err)
}
err = m.Up()
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return fmt.Errorf("failed to apply migrations: %v", err)
return fmt.Errorf("failed to apply migrations: %w", err)
}
return nil
@@ -78,9 +80,18 @@ func connectDatabase() (db *gorm.DB, err error) {
// Choose the correct database provider
switch common.EnvConfig.DbProvider {
case common.DbProviderSqlite:
dialector = sqlite.Open(common.EnvConfig.SqliteDBPath)
if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for SQLite database")
}
if !strings.HasPrefix(common.EnvConfig.DbConnectionString, "file:") {
return nil, errors.New("invalid value for env var 'DB_CONNECTION_STRING': does not begin with 'file:'")
}
dialector = sqlite.Open(common.EnvConfig.DbConnectionString)
case common.DbProviderPostgres:
dialector = postgres.Open(common.EnvConfig.PostgresConnectionString)
if common.EnvConfig.DbConnectionString == "" {
return nil, errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
dialector = postgres.Open(common.EnvConfig.DbConnectionString)
default:
return nil, fmt.Errorf("unsupported database provider: %s", common.EnvConfig.DbProvider)
}
@@ -91,14 +102,14 @@ func connectDatabase() (db *gorm.DB, err error) {
Logger: getLogger(),
})
if err == nil {
break
} else {
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
time.Sleep(3 * time.Second)
return db, nil
}
log.Printf("Attempt %d: Failed to initialize database. Retrying...", i)
time.Sleep(3 * time.Second)
}
return db, err
return nil, err
}
func getLogger() logger.Interface {

View File

@@ -0,0 +1,21 @@
//go:build e2etest
package bootstrap
import (
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/controller"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
// When building for E2E tests, add the e2etest controller
func init() {
registerTestControllers = []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService){
func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService) {
testService := service.NewTestService(db, appConfigService, jwtService)
controller.NewTestController(apiGroup, testService)
},
}
}

View File

@@ -0,0 +1,136 @@
package bootstrap
import (
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"fmt"
"log"
"os"
"path/filepath"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
privateKeyFilePem = "jwt_private_key.pem"
)
func migrateKey() {
err := migrateKeyInternal(common.EnvConfig.KeysPath)
if err != nil {
log.Fatalf("failed to perform migration of keys: %v", err)
}
}
func migrateKeyInternal(basePath string) error {
// First, check if there's already a JWK stored
jwkPath := filepath.Join(basePath, service.PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
}
if ok {
// There's already a key as JWK, so we don't do anything else here
return nil
}
// Check if there's a PEM file
pemPath := filepath.Join(basePath, privateKeyFilePem)
ok, err = utils.FileExists(pemPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (PEM) exists at path '%s': %w", pemPath, err)
}
if !ok {
// No file to migrate, return
return nil
}
// Load and validate the key
key, err := loadKeyPEM(pemPath)
if err != nil {
return fmt.Errorf("failed to load private key file (PEM) at path '%s': %w", pemPath, err)
}
err = service.ValidateKey(key)
if err != nil {
return fmt.Errorf("key object is invalid: %w", err)
}
// Save the key as JWK
err = service.SaveKeyJWK(key, jwkPath)
if err != nil {
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
// Finally, delete the PEM file
err = os.Remove(pemPath)
if err != nil {
return fmt.Errorf("failed to remove migrated key at path '%s': %w", pemPath, err)
}
return nil
}
func loadKeyPEM(path string) (jwk.Key, error) {
// Load the key from disk and parse it
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
// Populate the key ID using the "legacy" algorithm
keyId, err := generateKeyID(key)
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
err = key.Set(jwk.KeyIDKey, keyId)
if err != nil {
return nil, fmt.Errorf("failed to set key ID: %w", err)
}
// Populate other required fields
_ = key.Set(jwk.KeyUsageKey, service.KeyUsageSigning)
service.EnsureAlgInKey(key)
return key, nil
}
// generateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key's PKIX-serialized structure.
// This is used for legacy keys, imported from PEM.
func generateKeyID(key jwk.Key) (string, error) {
// Export the public key and serialize it to PKIX (not in a PEM block)
// This is for backwards-compatibility with the algorithm used before the switch to JWK
pubKey, err := key.PublicKey()
if err != nil {
return "", fmt.Errorf("failed to get public key: %w", err)
}
var pubKeyRaw any
err = jwk.Export(pubKey, &pubKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to export public key: %w", err)
}
pubASN1, err := x509.MarshalPKIXPublicKey(pubKeyRaw)
if err != nil {
return "", fmt.Errorf("failed to marshal public key: %w", err)
}
// Compute SHA-256 hash of the public key
hash := sha256.New()
hash.Write(pubASN1)
hashed := hash.Sum(nil)
// Truncate the hash to the first 8 bytes for a shorter Key ID
shortHash := hashed[:8]
// Return Base64 encoded truncated hash as Key ID
return base64.RawURLEncoding.EncodeToString(shortHash), nil
}

View File

@@ -0,0 +1,190 @@
package bootstrap
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"path/filepath"
"testing"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func TestMigrateKey(t *testing.T) {
// Create a temporary directory for testing
tempDir := t.TempDir()
t.Run("no keys exist", func(t *testing.T) {
// Test when no keys exist
err := migrateKeyInternal(tempDir)
require.NoError(t, err)
})
t.Run("jwk already exists", func(t *testing.T) {
// Create a JWK file
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
key, err := createTestRSAKey()
require.NoError(t, err)
err = service.SaveKeyJWK(key, jwkPath)
require.NoError(t, err)
// Run migration - should do nothing
err = migrateKeyInternal(tempDir)
require.NoError(t, err)
// Check the file still exists
exists, err := utils.FileExists(jwkPath)
require.NoError(t, err)
assert.True(t, exists)
// Delete for next test
err = os.Remove(jwkPath)
require.NoError(t, err)
})
t.Run("migrate pem to jwk", func(t *testing.T) {
// Create a PEM file
pemPath := filepath.Join(tempDir, privateKeyFilePem)
jwkPath := filepath.Join(tempDir, service.PrivateKeyFile)
// Generate RSA key and save as PEM
createRSAPrivateKeyPEM(t, pemPath)
// Run migration
err := migrateKeyInternal(tempDir)
require.NoError(t, err)
// Check PEM file is gone
exists, err := utils.FileExists(pemPath)
require.NoError(t, err)
assert.False(t, exists)
// Check JWK file exists
exists, err = utils.FileExists(jwkPath)
require.NoError(t, err)
assert.True(t, exists)
// Verify the JWK can be loaded
data, err := os.ReadFile(jwkPath)
require.NoError(t, err)
_, err = jwk.ParseKey(data)
require.NoError(t, err)
})
}
func TestLoadKeyPEM(t *testing.T) {
// Create a temporary directory for testing
tempDir := t.TempDir()
t.Run("successfully load PEM key", func(t *testing.T) {
pemPath := filepath.Join(tempDir, "test_key.pem")
// Generate RSA key and save as PEM
createRSAPrivateKeyPEM(t, pemPath)
// Load the key
key, err := loadKeyPEM(pemPath)
require.NoError(t, err)
// Verify key properties
assert.NotEmpty(t, key)
// Check key ID is set
var keyID string
err = key.Get(jwk.KeyIDKey, &keyID)
require.NoError(t, err)
assert.NotEmpty(t, keyID)
// Check algorithm is set
var alg jwa.SignatureAlgorithm
err = key.Get(jwk.AlgorithmKey, &alg)
require.NoError(t, err)
assert.NotEmpty(t, alg)
// Check key usage is set
var keyUsage string
err = key.Get(jwk.KeyUsageKey, &keyUsage)
require.NoError(t, err)
assert.Equal(t, service.KeyUsageSigning, keyUsage)
})
t.Run("file not found", func(t *testing.T) {
key, err := loadKeyPEM(filepath.Join(tempDir, "nonexistent.pem"))
require.Error(t, err)
assert.Nil(t, key)
})
t.Run("invalid file content", func(t *testing.T) {
invalidPath := filepath.Join(tempDir, "invalid.pem")
err := os.WriteFile(invalidPath, []byte("not a valid PEM"), 0600)
require.NoError(t, err)
key, err := loadKeyPEM(invalidPath)
require.Error(t, err)
assert.Nil(t, key)
})
}
func TestGenerateKeyID(t *testing.T) {
key, err := createTestRSAKey()
require.NoError(t, err)
keyID, err := generateKeyID(key)
require.NoError(t, err)
// Key ID should be non-empty
assert.NotEmpty(t, keyID)
// Generate another key ID to prove it depends on the key
key2, err := createTestRSAKey()
require.NoError(t, err)
keyID2, err := generateKeyID(key2)
require.NoError(t, err)
// The two key IDs should be different
assert.NotEqual(t, keyID, keyID2)
}
// Helper functions
func createTestRSAKey() (jwk.Key, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, err
}
key, err := jwk.Import(privateKey)
if err != nil {
return nil, err
}
return key, nil
}
// createRSAPrivateKeyPEM generates an RSA private key and returns its PEM-encoded form
func createRSAPrivateKeyPEM(t *testing.T, pemPath string) ([]byte, *rsa.PrivateKey) {
// Generate RSA key
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
// Encode to PEM format
pemData := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
})
err = os.WriteFile(pemPath, pemData, 0600)
require.NoError(t, err)
return pemData, privKey
}

View File

@@ -2,6 +2,7 @@ package bootstrap
import (
"log"
"net"
"time"
"github.com/gin-gonic/gin"
@@ -10,10 +11,17 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/job"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils/systemd"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
// This is used to register additional controllers for tests
var registerTestControllers []func(apiGroup *gin.RouterGroup, db *gorm.DB, appConfigService *service.AppConfigService, jwtService *service.JwtService)
// @title Pocket ID API
// @version 1
// @description API for Pocket ID
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
// Set the appropriate Gin mode based on the environment
switch common.EnvConfig.AppEnv {
@@ -41,9 +49,9 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
userService := service.NewUserService(db, jwtService, auditLogService, emailService, appConfigService)
customClaimService := service.NewCustomClaimService(db)
oidcService := service.NewOidcService(db, jwtService, appConfigService, auditLogService, customClaimService)
testService := service.NewTestService(db, appConfigService, jwtService)
userGroupService := service.NewUserGroupService(db, appConfigService)
ldapService := service.NewLdapService(db, appConfigService, userService, userGroupService)
apiKeyService := service.NewApiKeyService(db)
rateLimitMiddleware := middleware.NewRateLimitMiddleware()
@@ -51,36 +59,50 @@ func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
r.Use(middleware.NewCorsMiddleware().Add())
r.Use(middleware.NewErrorHandlerMiddleware().Add())
r.Use(rateLimitMiddleware.Add(rate.Every(time.Second), 60))
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
job.RegisterLdapJobs(ldapService, appConfigService)
job.RegisterDbCleanupJobs(db)
// Initialize middleware for specific routes
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
authMiddleware := middleware.NewAuthMiddleware(apiKeyService, jwtService)
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
// Set up API routes
apiGroup := r.Group("/api")
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, appConfigService)
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService)
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService, emailService, ldapService)
controller.NewAuditLogController(apiGroup, auditLogService, jwtAuthMiddleware)
controller.NewUserGroupController(apiGroup, jwtAuthMiddleware, userGroupService)
controller.NewCustomClaimController(apiGroup, jwtAuthMiddleware, customClaimService)
controller.NewApiKeyController(apiGroup, authMiddleware, apiKeyService)
controller.NewWebauthnController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, appConfigService)
controller.NewOidcController(apiGroup, authMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
controller.NewUserController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), userService, appConfigService)
controller.NewAppConfigController(apiGroup, authMiddleware, appConfigService, emailService, ldapService)
controller.NewAuditLogController(apiGroup, auditLogService, authMiddleware)
controller.NewUserGroupController(apiGroup, authMiddleware, userGroupService)
controller.NewCustomClaimController(apiGroup, authMiddleware, customClaimService)
// Add test controller in non-production environments
if common.EnvConfig.AppEnv != "production" {
controller.NewTestController(apiGroup, testService)
for _, f := range registerTestControllers {
f(apiGroup, db, appConfigService, jwtService)
}
}
// Set up base routes
baseGroup := r.Group("/")
controller.NewWellKnownController(baseGroup, jwtService)
// Run the server
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
// Get the listener
l, err := net.Listen("tcp", common.EnvConfig.Host+":"+common.EnvConfig.Port)
if err != nil {
log.Fatal(err)
}
// Notify systemd that we are ready
if err := systemd.SdNotifyReady(); err != nil {
log.Println("Unable to notify systemd that the service is ready: ", err)
// continue to serve anyway since it's not that important
}
// Serve requests
if err := r.RunListener(l); err != nil {
log.Fatal(err)
}
}

View File

@@ -2,6 +2,7 @@ package common
import (
"log"
"net/url"
"github.com/caarlos0/env/v11"
_ "github.com/joho/godotenv/autoload"
@@ -19,9 +20,11 @@ type EnvConfigSchema struct {
AppEnv string `env:"APP_ENV"`
AppURL string `env:"PUBLIC_APP_URL"`
DbProvider DbProvider `env:"DB_PROVIDER"`
SqliteDBPath string `env:"SQLITE_DB_PATH"`
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"`
DbConnectionString string `env:"DB_CONNECTION_STRING"`
SqliteDBPath string `env:"SQLITE_DB_PATH"` // Deprecated: use "DB_CONNECTION_STRING" instead
PostgresConnectionString string `env:"POSTGRES_CONNECTION_STRING"` // Deprecated: use "DB_CONNECTION_STRING" instead
UploadPath string `env:"UPLOAD_PATH"`
KeysPath string `env:"KEYS_PATH"`
Port string `env:"BACKEND_PORT"`
Host string `env:"HOST"`
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY"`
@@ -33,9 +36,11 @@ type EnvConfigSchema struct {
var EnvConfig = &EnvConfigSchema{
AppEnv: "production",
DbProvider: "sqlite",
SqliteDBPath: "data/pocket-id.db",
DbConnectionString: "file:data/pocket-id.db?_journal_mode=WAL&_busy_timeout=2500&_txlock=immediate",
SqliteDBPath: "",
PostgresConnectionString: "",
UploadPath: "data/uploads",
KeysPath: "data/keys",
AppURL: "http://localhost",
Port: "8080",
Host: "0.0.0.0",
@@ -49,16 +54,26 @@ func init() {
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
log.Fatal(err)
}
// Validate the environment variables
if EnvConfig.DbProvider != DbProviderSqlite && EnvConfig.DbProvider != DbProviderPostgres {
switch EnvConfig.DbProvider {
case DbProviderSqlite:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for SQLite database")
}
case DbProviderPostgres:
if EnvConfig.DbConnectionString == "" {
log.Fatal("Missing required env var 'DB_CONNECTION_STRING' for Postgres database")
}
default:
log.Fatal("Invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
}
if EnvConfig.DbProvider == DbProviderPostgres && EnvConfig.PostgresConnectionString == "" {
log.Fatal("Missing POSTGRES_CONNECTION_STRING environment variable")
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
if err != nil {
log.Fatal("PUBLIC_APP_URL is not a valid URL")
}
if EnvConfig.DbProvider == DbProviderSqlite && EnvConfig.SqliteDBPath == "" {
log.Fatal("Missing SQLITE_DB_PATH environment variable")
if parsedAppUrl.Path != "" {
log.Fatal("PUBLIC_APP_URL must not contain a path")
}
}

View File

@@ -94,6 +94,11 @@ type NotSignedInError struct{}
func (e *NotSignedInError) Error() string { return "You are not signed in" }
func (e *NotSignedInError) HttpStatusCode() int { return http.StatusUnauthorized }
type MissingAccessToken struct{}
func (e *MissingAccessToken) Error() string { return "Missing access token" }
func (e *MissingAccessToken) HttpStatusCode() int { return http.StatusUnauthorized }
type MissingPermissionError struct{}
func (e *MissingPermissionError) Error() string {
@@ -211,3 +216,72 @@ func (e *UiConfigDisabledError) Error() string {
return "The configuration can't be changed since the UI configuration is disabled"
}
func (e *UiConfigDisabledError) HttpStatusCode() int { return http.StatusForbidden }
type InvalidUUIDError struct{}
func (e *InvalidUUIDError) Error() string {
return "Invalid UUID"
}
type InvalidEmailError struct{}
type OneTimeAccessDisabledError struct{}
func (e *OneTimeAccessDisabledError) Error() string {
return "One-time access is disabled"
}
func (e *OneTimeAccessDisabledError) HttpStatusCode() int { return http.StatusBadRequest }
type InvalidAPIKeyError struct{}
func (e *InvalidAPIKeyError) Error() string {
return "Invalid Api Key"
}
type NoAPIKeyProvidedError struct{}
func (e *NoAPIKeyProvidedError) Error() string {
return "No API Key Provided"
}
type APIKeyNotFoundError struct{}
func (e *APIKeyNotFoundError) Error() string {
return "API Key Not Found"
}
type APIKeyExpirationDateError struct{}
func (e *APIKeyExpirationDateError) Error() string {
return "API Key expiration time must be in the future"
}
type OidcInvalidRefreshTokenError struct{}
func (e *OidcInvalidRefreshTokenError) Error() string {
return "refresh token is invalid or expired"
}
func (e *OidcInvalidRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingRefreshTokenError struct{}
func (e *OidcMissingRefreshTokenError) Error() string {
return "refresh token is required"
}
func (e *OidcMissingRefreshTokenError) HttpStatusCode() int {
return http.StatusBadRequest
}
type OidcMissingAuthorizationCodeError struct{}
func (e *OidcMissingAuthorizationCodeError) Error() string {
return "authorization code is required"
}
func (e *OidcMissingAuthorizationCodeError) HttpStatusCode() int {
return http.StatusBadRequest
}

View File

@@ -0,0 +1,125 @@
package controller
import (
"net/http"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
// swag init -g cmd/main.go -o ./docs/swagger --parseDependency
// ApiKeyController manages API keys for authenticated users
type ApiKeyController struct {
apiKeyService *service.ApiKeyService
}
// NewApiKeyController creates a new controller for API key management
// @Summary API key management controller
// @Description Initializes API endpoints for managing API keys
// @Tags API Keys
func NewApiKeyController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, apiKeyService *service.ApiKeyService) {
uc := &ApiKeyController{apiKeyService: apiKeyService}
apiKeyGroup := group.Group("/api-keys")
apiKeyGroup.Use(authMiddleware.WithAdminNotRequired().Add())
{
apiKeyGroup.GET("", uc.listApiKeysHandler)
apiKeyGroup.POST("", uc.createApiKeyHandler)
apiKeyGroup.DELETE("/:id", uc.revokeApiKeyHandler)
}
}
// listApiKeysHandler godoc
// @Summary List API keys
// @Description Get a paginated list of API keys belonging to the current user
// @Tags API Keys
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.ApiKeyDto]
// @Router /api/api-keys [get]
func (c *ApiKeyController) listApiKeysHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
var sortedPaginationRequest utils.SortedPaginationRequest
if err := ctx.ShouldBindQuery(&sortedPaginationRequest); err != nil {
_ = ctx.Error(err)
return
}
apiKeys, pagination, err := c.apiKeyService.ListApiKeys(userID, sortedPaginationRequest)
if err != nil {
_ = ctx.Error(err)
return
}
var apiKeysDto []dto.ApiKeyDto
if err := dto.MapStructList(apiKeys, &apiKeysDto); err != nil {
_ = ctx.Error(err)
return
}
ctx.JSON(http.StatusOK, dto.Paginated[dto.ApiKeyDto]{
Data: apiKeysDto,
Pagination: pagination,
})
}
// createApiKeyHandler godoc
// @Summary Create API key
// @Description Create a new API key for the current user
// @Tags API Keys
// @Param api_key body dto.ApiKeyCreateDto true "API key information"
// @Success 201 {object} dto.ApiKeyResponseDto "Created API key with token"
// @Router /api/api-keys [post]
func (c *ApiKeyController) createApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
var input dto.ApiKeyCreateDto
if err := ctx.ShouldBindJSON(&input); err != nil {
_ = ctx.Error(err)
return
}
apiKey, token, err := c.apiKeyService.CreateApiKey(userID, input)
if err != nil {
_ = ctx.Error(err)
return
}
var apiKeyDto dto.ApiKeyDto
if err := dto.MapStruct(apiKey, &apiKeyDto); err != nil {
_ = ctx.Error(err)
return
}
ctx.JSON(http.StatusCreated, dto.ApiKeyResponseDto{
ApiKey: apiKeyDto,
Token: token,
})
}
// revokeApiKeyHandler godoc
// @Summary Revoke API key
// @Description Revoke (delete) an existing API key by ID
// @Tags API Keys
// @Param id path string true "API Key ID"
// @Success 204 "No Content"
// @Router /api/api-keys/{id} [delete]
func (c *ApiKeyController) revokeApiKeyHandler(ctx *gin.Context) {
userID := ctx.GetString("userID")
apiKeyID := ctx.Param("id")
if err := c.apiKeyService.RevokeApiKey(userID, apiKeyID); err != nil {
_ = ctx.Error(err)
return
}
ctx.Status(http.StatusNoContent)
}

View File

@@ -3,6 +3,7 @@ package controller
import (
"fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
@@ -12,9 +13,13 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
// NewAppConfigController creates a new controller for application configuration endpoints
// @Summary Create a new application configuration controller
// @Description Initialize routes for application configuration
// @Tags Application Configuration
func NewAppConfigController(
group *gin.RouterGroup,
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
authMiddleware *middleware.AuthMiddleware,
appConfigService *service.AppConfigService,
emailService *service.EmailService,
ldapService *service.LdapService,
@@ -26,18 +31,18 @@ func NewAppConfigController(
ldapService: ldapService,
}
group.GET("/application-configuration", acc.listAppConfigHandler)
group.GET("/application-configuration/all", jwtAuthMiddleware.Add(true), acc.listAllAppConfigHandler)
group.PUT("/application-configuration", acc.updateAppConfigHandler)
group.GET("/application-configuration/all", authMiddleware.Add(), acc.listAllAppConfigHandler)
group.PUT("/application-configuration", authMiddleware.Add(), acc.updateAppConfigHandler)
group.GET("/application-configuration/logo", acc.getLogoHandler)
group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler)
group.GET("/application-configuration/favicon", acc.getFaviconHandler)
group.PUT("/application-configuration/logo", jwtAuthMiddleware.Add(true), acc.updateLogoHandler)
group.PUT("/application-configuration/favicon", jwtAuthMiddleware.Add(true), acc.updateFaviconHandler)
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
group.PUT("/application-configuration/logo", authMiddleware.Add(), acc.updateLogoHandler)
group.PUT("/application-configuration/favicon", authMiddleware.Add(), acc.updateFaviconHandler)
group.PUT("/application-configuration/background-image", authMiddleware.Add(), acc.updateBackgroundImageHandler)
group.POST("/application-configuration/test-email", jwtAuthMiddleware.Add(true), acc.testEmailHandler)
group.POST("/application-configuration/sync-ldap", jwtAuthMiddleware.Add(true), acc.syncLdapHandler)
group.POST("/application-configuration/test-email", authMiddleware.Add(), acc.testEmailHandler)
group.POST("/application-configuration/sync-ldap", authMiddleware.Add(), acc.syncLdapHandler)
}
type AppConfigController struct {
@@ -46,62 +51,100 @@ type AppConfigController struct {
ldapService *service.LdapService
}
// listAppConfigHandler godoc
// @Summary List public application configurations
// @Description Get all public application configurations
// @Tags Application Configuration
// @Accept json
// @Produce json
// @Success 200 {array} dto.PublicAppConfigVariableDto
// @Failure 500 {object} object "{"error": "error message"}"
// @Router /application-configuration [get]
func (acc *AppConfigController) listAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(false)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var configVariablesDto []dto.PublicAppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(200, configVariablesDto)
}
// listAllAppConfigHandler godoc
// @Summary List all application configurations
// @Description Get all application configurations including private ones
// @Tags Application Configuration
// @Accept json
// @Produce json
// @Success 200 {array} dto.AppConfigVariableDto
// @Security BearerAuth
// @Router /application-configuration/all [get]
func (acc *AppConfigController) listAllAppConfigHandler(c *gin.Context) {
configuration, err := acc.appConfigService.ListAppConfig(true)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(200, configVariablesDto)
}
// updateAppConfigHandler godoc
// @Summary Update application configurations
// @Description Update application configuration settings
// @Tags Application Configuration
// @Accept json
// @Produce json
// @Param body body dto.AppConfigUpdateDto true "Application Configuration"
// @Success 200 {array} dto.AppConfigVariableDto
// @Security BearerAuth
// @Router /api/application-configuration [put]
func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
var input dto.AppConfigUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
savedConfigVariables, err := acc.appConfigService.UpdateAppConfig(input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var configVariablesDto []dto.AppConfigVariableDto
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, configVariablesDto)
}
// getLogoHandler godoc
// @Summary Get logo image
// @Description Get the logo image for the application
// @Tags Application Configuration
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
// @Produce image/png
// @Produce image/jpeg
// @Produce image/svg+xml
// @Success 200 {file} binary "Logo image"
// @Router /api/application-configuration/logo [get]
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
lightLogo := c.DefaultQuery("light", "true") == "true"
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string
var imageType string
@@ -117,17 +160,44 @@ func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
acc.getImage(c, imageName, imageType)
}
// getFaviconHandler godoc
// @Summary Get favicon
// @Description Get the favicon for the application
// @Tags Application Configuration
// @Produce image/x-icon
// @Success 200 {file} binary "Favicon image"
// @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/favicon [get]
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
acc.getImage(c, "favicon", "ico")
}
// getBackgroundImageHandler godoc
// @Summary Get background image
// @Description Get the background image for the application
// @Tags Application Configuration
// @Produce image/png
// @Produce image/jpeg
// @Success 200 {file} binary "Background image"
// @Failure 404 {object} object "{"error": "File not found"}"
// @Router /api/application-configuration/background-image [get]
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.getImage(c, "background", imageType)
}
// updateLogoHandler godoc
// @Summary Update logo
// @Description Update the application logo
// @Tags Application Configuration
// @Accept multipart/form-data
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
// @Param file formData file true "Logo image file"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/logo [put]
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
lightLogo := c.DefaultQuery("light", "true") == "true"
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
var imageName string
var imageType string
@@ -143,26 +213,45 @@ func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
acc.updateImage(c, imageName, imageType)
}
// updateFaviconHandler godoc
// @Summary Update favicon
// @Description Update the application favicon
// @Tags Application Configuration
// @Accept multipart/form-data
// @Param file formData file true "Favicon file (.ico)"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/favicon [put]
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
fileType := utils.GetFileExtension(file.Filename)
if fileType != "ico" {
c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
_ = c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
return
}
acc.updateImage(c, "favicon", "ico")
}
// updateBackgroundImageHandler godoc
// @Summary Update background image
// @Description Update the application background image
// @Tags Application Configuration
// @Accept multipart/form-data
// @Param file formData file true "Background image file"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/background-image [put]
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
acc.updateImage(c, "background", imageType)
}
// getImage is a helper function to serve image files
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
mimeType := utils.GetImageMimeType(imageType)
@@ -171,37 +260,53 @@ func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType
c.File(imagePath)
}
// updateImage is a helper function to update image files
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
file, err := c.FormFile("file")
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// syncLdapHandler godoc
// @Summary Synchronize LDAP
// @Description Manually trigger LDAP synchronization
// @Tags Application Configuration
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/sync-ldap [post]
func (acc *AppConfigController) syncLdapHandler(c *gin.Context) {
err := acc.ldapService.SyncAll()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// testEmailHandler godoc
// @Summary Send test email
// @Description Send a test email to verify email configuration
// @Tags Application Configuration
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/application-configuration/test-email [post]
func (acc *AppConfigController) testEmailHandler(c *gin.Context) {
userID := c.GetString("userID")
err := acc.emailService.SendTestEmail(userID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -11,22 +11,36 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/service"
)
func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, jwtAuthMiddleware *middleware.JwtAuthMiddleware) {
// NewAuditLogController creates a new controller for audit log management
// @Summary Audit log controller
// @Description Initializes API endpoints for accessing audit logs
// @Tags Audit Logs
func NewAuditLogController(group *gin.RouterGroup, auditLogService *service.AuditLogService, authMiddleware *middleware.AuthMiddleware) {
alc := AuditLogController{
auditLogService: auditLogService,
}
group.GET("/audit-logs", jwtAuthMiddleware.Add(false), alc.listAuditLogsForUserHandler)
group.GET("/audit-logs", authMiddleware.WithAdminNotRequired().Add(), alc.listAuditLogsForUserHandler)
}
type AuditLogController struct {
auditLogService *service.AuditLogService
}
// listAuditLogsForUserHandler godoc
// @Summary List audit logs
// @Description Get a paginated list of audit logs for the current user
// @Tags Audit Logs
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.AuditLogDto]
// @Router /api/audit-logs [get]
func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -35,7 +49,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
// Fetch audit logs for the user
logs, pagination, err := alc.auditLogService.ListAuditLogsForUser(userID, sortedPaginationRequest)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -43,7 +57,7 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
var logsDtos []dto.AuditLogDto
err = dto.MapStructList(logs, &logsDtos)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -53,8 +67,8 @@ func (alc *AuditLogController) listAuditLogsForUserHandler(c *gin.Context) {
logsDtos[i] = logsDto
}
c.JSON(http.StatusOK, gin.H{
"data": logsDtos,
"pagination": pagination,
c.JSON(http.StatusOK, dto.Paginated[dto.AuditLogDto]{
Data: logsDtos,
Pagination: pagination,
})
}

View File

@@ -9,69 +9,110 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/service"
)
func NewCustomClaimController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, customClaimService *service.CustomClaimService) {
// NewCustomClaimController creates a new controller for custom claim management
// @Summary Custom claim management controller
// @Description Initializes all custom claim-related API endpoints
// @Tags Custom Claims
func NewCustomClaimController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, customClaimService *service.CustomClaimService) {
wkc := &CustomClaimController{customClaimService: customClaimService}
group.GET("/custom-claims/suggestions", jwtAuthMiddleware.Add(true), wkc.getSuggestionsHandler)
group.PUT("/custom-claims/user/:userId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserHandler)
group.PUT("/custom-claims/user-group/:userGroupId", jwtAuthMiddleware.Add(true), wkc.UpdateCustomClaimsForUserGroupHandler)
customClaimsGroup := group.Group("/custom-claims")
customClaimsGroup.Use(authMiddleware.Add())
{
customClaimsGroup.GET("/suggestions", wkc.getSuggestionsHandler)
customClaimsGroup.PUT("/user/:userId", wkc.UpdateCustomClaimsForUserHandler)
customClaimsGroup.PUT("/user-group/:userGroupId", wkc.UpdateCustomClaimsForUserGroupHandler)
}
}
type CustomClaimController struct {
customClaimService *service.CustomClaimService
}
// getSuggestionsHandler godoc
// @Summary Get custom claim suggestions
// @Description Get a list of suggested custom claim names
// @Tags Custom Claims
// @Produce json
// @Success 200 {array} string "List of suggested custom claim names"
// @Failure 401 {object} object "Unauthorized"
// @Failure 403 {object} object "Forbidden"
// @Failure 500 {object} object "Internal server error"
// @Security BearerAuth
// @Router /api/custom-claims/suggestions [get]
func (ccc *CustomClaimController) getSuggestionsHandler(c *gin.Context) {
claims, err := ccc.customClaimService.GetSuggestions()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, claims)
}
// UpdateCustomClaimsForUserHandler godoc
// @Summary Update custom claims for a user
// @Description Update or create custom claims for a specific user
// @Tags Custom Claims
// @Accept json
// @Produce json
// @Param userId path string true "User ID"
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user"
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
// @Router /api/custom-claims/user/{userId} [put]
func (ccc *CustomClaimController) UpdateCustomClaimsForUserHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
userId := c.Param("userId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUser(userId, input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, customClaimsDto)
}
// UpdateCustomClaimsForUserGroupHandler godoc
// @Summary Update custom claims for a user group
// @Description Update or create custom claims for a specific user group
// @Tags Custom Claims
// @Accept json
// @Produce json
// @Param userGroupId path string true "User Group ID"
// @Param claims body []dto.CustomClaimCreateDto true "List of custom claims to set for the user group"
// @Success 200 {array} dto.CustomClaimDto "Updated custom claims"
// @Security BearerAuth
// @Router /api/custom-claims/user-group/{userGroupId} [put]
func (ccc *CustomClaimController) UpdateCustomClaimsForUserGroupHandler(c *gin.Context) {
var input []dto.CustomClaimCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
userId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userId, input)
userGroupId := c.Param("userGroupId")
claims, err := ccc.customClaimService.UpdateCustomClaimsForUserGroup(userGroupId, input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var customClaimsDto []dto.CustomClaimDto
if err := dto.MapStructList(claims, &customClaimsDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -1,9 +1,12 @@
//go:build e2etest
package controller
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
@@ -19,22 +22,22 @@ type TestController struct {
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
if err := tc.TestService.ResetDatabase(); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
if err := tc.TestService.ResetApplicationImages(); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
if err := tc.TestService.SeedDatabase(); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
if err := tc.TestService.ResetAppConfig(); err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -1,13 +1,14 @@
package controller
import (
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"log"
"net/http"
"net/url"
"strings"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/middleware"
@@ -15,29 +16,35 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
// NewOidcController creates a new controller for OIDC related endpoints
// @Summary OIDC controller
// @Description Initializes all OIDC-related API endpoints for authentication and client management
// @Tags OIDC
func NewOidcController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
group.POST("/oidc/authorization-required", jwtAuthMiddleware.Add(false), oc.authorizationConfirmationRequiredHandler)
group.POST("/oidc/authorize", authMiddleware.WithAdminNotRequired().Add(), oc.authorizeHandler)
group.POST("/oidc/authorization-required", authMiddleware.WithAdminNotRequired().Add(), oc.authorizationConfirmationRequiredHandler)
group.POST("/oidc/token", oc.createTokensHandler)
group.GET("/oidc/userinfo", oc.userInfoHandler)
group.POST("/oidc/end-session", oc.EndSessionHandler)
group.GET("/oidc/end-session", oc.EndSessionHandler)
group.POST("/oidc/userinfo", oc.userInfoHandler)
group.POST("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.GET("/oidc/end-session", authMiddleware.WithSuccessOptional().Add(), oc.EndSessionHandler)
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
group.GET("/oidc/clients/:id", oc.getClientHandler)
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
group.GET("/oidc/clients", authMiddleware.Add(), oc.listClientsHandler)
group.POST("/oidc/clients", authMiddleware.Add(), oc.createClientHandler)
group.GET("/oidc/clients/:id", authMiddleware.Add(), oc.getClientHandler)
group.GET("/oidc/clients/:id/meta", oc.getClientMetaDataHandler)
group.PUT("/oidc/clients/:id", authMiddleware.Add(), oc.updateClientHandler)
group.DELETE("/oidc/clients/:id", authMiddleware.Add(), oc.deleteClientHandler)
group.PUT("/oidc/clients/:id/allowed-user-groups", jwtAuthMiddleware.Add(true), oc.updateAllowedUserGroupsHandler)
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
group.PUT("/oidc/clients/:id/allowed-user-groups", authMiddleware.Add(), oc.updateAllowedUserGroupsHandler)
group.POST("/oidc/clients/:id/secret", authMiddleware.Add(), oc.createClientSecretHandler)
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
group.POST("/oidc/clients/:id/logo", jwtAuthMiddleware.Add(true), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
group.POST("/oidc/clients/:id/logo", authMiddleware.Add(), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
}
type OidcController struct {
@@ -45,16 +52,26 @@ type OidcController struct {
jwtService *service.JwtService
}
// authorizeHandler godoc
// @Summary Authorize OIDC client
// @Description Start the OIDC authorization process for a client
// @Tags OIDC
// @Accept json
// @Produce json
// @Param request body dto.AuthorizeOidcClientRequestDto true "Authorization request parameters"
// @Success 200 {object} dto.AuthorizeOidcClientResponseDto "Authorization code and callback URL"
// @Security BearerAuth
// @Router /api/oidc/authorize [post]
func (oc *OidcController) authorizeHandler(c *gin.Context) {
var input dto.AuthorizeOidcClientRequestDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -66,30 +83,64 @@ func (oc *OidcController) authorizeHandler(c *gin.Context) {
c.JSON(http.StatusOK, response)
}
// authorizationConfirmationRequiredHandler godoc
// @Summary Check if authorization confirmation is required
// @Description Check if the user needs to confirm authorization for the client
// @Tags OIDC
// @Accept json
// @Produce json
// @Param request body dto.AuthorizationRequiredDto true "Authorization check parameters"
// @Success 200 {object} object "{ \"authorizationRequired\": true/false }"
// @Security BearerAuth
// @Router /api/oidc/authorization-required [post]
func (oc *OidcController) authorizationConfirmationRequiredHandler(c *gin.Context) {
var input dto.AuthorizationRequiredDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
hasAuthorizedClient, err := oc.oidcService.HasAuthorizedClient(input.ClientID, c.GetString("userID"), input.Scope)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{"authorizationRequired": !hasAuthorizedClient})
}
// createTokensHandler godoc
// @Summary Create OIDC tokens
// @Description Exchange authorization code or refresh token for access tokens
// @Tags OIDC
// @Produce json
// @Param client_id formData string false "Client ID (if not using Basic Auth)"
// @Param client_secret formData string false "Client secret (if not using Basic Auth)"
// @Param code formData string false "Authorization code (required for 'authorization_code' grant)"
// @Param grant_type formData string true "Grant type ('authorization_code' or 'refresh_token')"
// @Param code_verifier formData string false "PKCE code verifier (for authorization_code with PKCE)"
// @Param refresh_token formData string false "Refresh token (required for 'refresh_token' grant)"
// @Success 200 {object} dto.OidcTokenResponseDto "Token response with access_token and optional id_token and refresh_token"
// @Router /api/oidc/token [post]
func (oc *OidcController) createTokensHandler(c *gin.Context) {
// Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
var input dto.OidcCreateTokensDto
if err := c.ShouldBind(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
// Validate that code is provided for authorization_code grant type
if input.GrantType == "authorization_code" && input.Code == "" {
_ = c.Error(&common.OidcMissingAuthorizationCodeError{})
return
}
// Validate that refresh_token is provided for refresh_token grant type
if input.GrantType == "refresh_token" && input.RefreshToken == "" {
_ = c.Error(&common.OidcMissingRefreshTokenError{})
return
}
@@ -101,46 +152,104 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientID, clientSecret, _ = c.Request.BasicAuth()
}
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier)
idToken, accessToken, refreshToken, expiresIn, err := oc.oidcService.CreateTokens(
input.Code,
input.GrantType,
clientID,
clientSecret,
input.CodeVerifier,
input.RefreshToken,
)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"})
response := dto.OidcTokenResponseDto{
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: expiresIn,
}
// Include ID token only for authorization_code grant
if idToken != "" {
response.IdToken = idToken
}
// Include refresh token if generated
if refreshToken != "" {
response.RefreshToken = refreshToken
}
c.JSON(http.StatusOK, response)
}
// userInfoHandler godoc
// @Summary Get user information
// @Description Get user information based on the access token
// @Tags OIDC
// @Accept json
// @Produce json
// @Success 200 {object} object "User claims based on requested scopes"
// @Security OAuth2AccessToken
// @Router /api/oidc/userinfo [get]
func (oc *OidcController) userInfoHandler(c *gin.Context) {
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
if err != nil {
c.Error(err)
_, authToken, ok := strings.Cut(c.GetHeader("Authorization"), " ")
if !ok || authToken == "" {
_ = c.Error(&common.MissingAccessToken{})
return
}
userID := jwtClaims.Subject
clientId := jwtClaims.Audience[0]
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
token, err := oc.jwtService.VerifyOauthAccessToken(authToken)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
userID, ok := token.Subject()
if !ok {
_ = c.Error(&common.TokenInvalidError{})
return
}
clientID, ok := token.Audience()
if !ok || len(clientID) != 1 {
_ = c.Error(&common.TokenInvalidError{})
return
}
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientID[0])
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, claims)
}
// EndSessionHandler godoc
// @Summary End OIDC session
// @Description End user session and handle OIDC logout
// @Tags OIDC
// @Accept application/x-www-form-urlencoded
// @Produce html
// @Param id_token_hint query string false "ID token"
// @Param post_logout_redirect_uri query string false "URL to redirect to after logout"
// @Param state query string false "State parameter to include in the redirect"
// @Success 302 "Redirect to post-logout URL or application logout page"
// @Router /api/oidc/end-session [get]
func (oc *OidcController) EndSessionHandler(c *gin.Context) {
var input dto.OidcLogoutDto
// Bind query parameters to the struct
if c.Request.Method == http.MethodGet {
switch c.Request.Method {
case http.MethodGet:
if err := c.ShouldBindQuery(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
} else if c.Request.Method == http.MethodPost {
case http.MethodPost:
// Bind form parameters to the struct
if err := c.ShouldBind(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
}
@@ -166,128 +275,228 @@ func (oc *OidcController) EndSessionHandler(c *gin.Context) {
c.Redirect(http.StatusFound, logoutCallbackURL.String())
}
// EndSessionHandler godoc (POST method)
// @Summary End OIDC session (POST method)
// @Description End user session and handle OIDC logout using POST
// @Tags OIDC
// @Accept application/x-www-form-urlencoded
// @Produce html
// @Param id_token_hint formData string false "ID token"
// @Param post_logout_redirect_uri formData string false "URL to redirect to after logout"
// @Param state formData string false "State parameter to include in the redirect"
// @Success 302 "Redirect to post-logout URL or application logout page"
// @Router /api/oidc/end-session [post]
func (oc *OidcController) EndSessionHandlerPost(c *gin.Context) {
// Implementation is the same as GET
}
// getClientMetaDataHandler godoc
// @Summary Get client metadata
// @Description Get OIDC client metadata for discovery and configuration
// @Tags OIDC
// @Produce json
// @Param id path string true "Client ID"
// @Success 200 {object} dto.OidcClientMetaDataDto "Client metadata"
// @Router /api/oidc/clients/{id}/meta [get]
func (oc *OidcController) getClientMetaDataHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId)
if err != nil {
_ = c.Error(err)
return
}
clientDto := dto.OidcClientMetaDataDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
_ = c.Error(err)
}
// getClientHandler godoc
// @Summary Get OIDC client
// @Description Get detailed information about an OIDC client
// @Tags OIDC
// @Produce json
// @Param id path string true "Client ID"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Client information"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [get]
func (oc *OidcController) getClientHandler(c *gin.Context) {
clientId := c.Param("id")
client, err := oc.oidcService.GetClient(clientId)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
// Return a different DTO based on the user's role
if c.GetBool("userIsAdmin") {
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
} else {
clientDto := dto.PublicOidcClientDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
clientDto := dto.OidcClientWithAllowedUserGroupsDto{}
err = dto.MapStruct(client, &clientDto)
if err == nil {
c.JSON(http.StatusOK, clientDto)
return
}
c.Error(err)
_ = c.Error(err)
}
// listClientsHandler godoc
// @Summary List OIDC clients
// @Description Get a paginated list of OIDC clients with optional search and sorting
// @Tags OIDC
// @Param search query string false "Search term to filter clients by name"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("name")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.OidcClientDto]
// @Security BearerAuth
// @Router /api/oidc/clients [get]
func (oc *OidcController) listClientsHandler(c *gin.Context) {
searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
clients, pagination, err := oc.oidcService.ListClients(searchTerm, sortedPaginationRequest)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var clientsDto []dto.OidcClientDto
if err := dto.MapStructList(clients, &clientsDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": clientsDto,
"pagination": pagination,
c.JSON(http.StatusOK, dto.Paginated[dto.OidcClientDto]{
Data: clientsDto,
Pagination: pagination,
})
}
// createClientHandler godoc
// @Summary Create OIDC client
// @Description Create a new OIDC client
// @Tags OIDC
// @Accept json
// @Produce json
// @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 201 {object} dto.OidcClientWithAllowedUserGroupsDto "Created client"
// @Security BearerAuth
// @Router /api/oidc/clients [post]
func (oc *OidcController) createClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, clientDto)
}
// deleteClientHandler godoc
// @Summary Delete OIDC client
// @Description Delete an OIDC client by ID
// @Tags OIDC
// @Param id path string true "Client ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [delete]
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
err := oc.oidcService.DeleteClient(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// updateClientHandler godoc
// @Summary Update OIDC client
// @Description Update an existing OIDC client
// @Tags OIDC
// @Accept json
// @Produce json
// @Param id path string true "Client ID"
// @Param client body dto.OidcClientCreateDto true "Client information"
// @Success 200 {object} dto.OidcClientWithAllowedUserGroupsDto "Updated client"
// @Security BearerAuth
// @Router /api/oidc/clients/{id} [put]
func (oc *OidcController) updateClientHandler(c *gin.Context) {
var input dto.OidcClientCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var clientDto dto.OidcClientWithAllowedUserGroupsDto
if err := dto.MapStruct(client, &clientDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, clientDto)
}
// createClientSecretHandler godoc
// @Summary Create client secret
// @Description Generate a new secret for an OIDC client
// @Tags OIDC
// @Produce json
// @Param id path string true "Client ID"
// @Success 200 {object} object "{ \"secret\": \"string\" }"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/secret [post]
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{"secret": secret})
}
// getClientLogoHandler godoc
// @Summary Get client logo
// @Description Get the logo image for an OIDC client
// @Tags OIDC
// @Produce image/png
// @Produce image/jpeg
// @Produce image/svg+xml
// @Param id path string true "Client ID"
// @Success 200 {file} binary "Logo image"
// @Router /api/oidc/clients/{id}/logo [get]
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -295,48 +504,77 @@ func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
c.File(imagePath)
}
// updateClientLogoHandler godoc
// @Summary Update client logo
// @Description Upload or update the logo for an OIDC client
// @Tags OIDC
// @Accept multipart/form-data
// @Param id path string true "Client ID"
// @Param file formData file true "Logo image file (PNG, JPG, or SVG, max 2MB)"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [post]
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
file, err := c.FormFile("file")
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// deleteClientLogoHandler godoc
// @Summary Delete client logo
// @Description Delete the logo for an OIDC client
// @Tags OIDC
// @Param id path string true "Client ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/logo [delete]
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// updateAllowedUserGroupsHandler godoc
// @Summary Update allowed user groups
// @Description Update the user groups allowed to access an OIDC client
// @Tags OIDC
// @Accept json
// @Produce json
// @Param id path string true "Client ID"
// @Param groups body dto.OidcUpdateAllowedUserGroupsDto true "User group IDs"
// @Success 200 {object} dto.OidcClientDto "Updated client"
// @Security BearerAuth
// @Router /api/oidc/clients/{id}/allowed-user-groups [put]
func (oc *OidcController) updateAllowedUserGroupsHandler(c *gin.Context) {
var input dto.OidcUpdateAllowedUserGroupsDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
oidcClient, err := oc.oidcService.UpdateAllowedUserGroups(c.Param("id"), input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var oidcClientDto dto.OidcClientDto
if err := dto.MapStruct(oidcClient, &oidcClientDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -2,7 +2,6 @@ package controller
import (
"net/http"
"strconv"
"time"
"github.com/pocket-id/pocket-id/backend/internal/utils/cookie"
@@ -16,24 +15,40 @@ import (
"golang.org/x/time/rate"
)
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService, appConfigService *service.AppConfigService) {
// NewUserController creates a new controller for user management endpoints
// @Summary User management controller
// @Description Initializes all user-related API endpoints
// @Tags Users
func NewUserController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService, appConfigService *service.AppConfigService) {
uc := UserController{
userService: userService,
appConfigService: appConfigService,
}
group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler)
group.GET("/users/me", jwtAuthMiddleware.Add(false), uc.getCurrentUserHandler)
group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler)
group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler)
group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler)
group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler)
group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler)
group.GET("/users", authMiddleware.Add(), uc.listUsersHandler)
group.GET("/users/me", authMiddleware.WithAdminNotRequired().Add(), uc.getCurrentUserHandler)
group.GET("/users/:id", authMiddleware.Add(), uc.getUserHandler)
group.POST("/users", authMiddleware.Add(), uc.createUserHandler)
group.PUT("/users/:id", authMiddleware.Add(), uc.updateUserHandler)
group.GET("/users/:id/groups", authMiddleware.Add(), uc.getUserGroupsHandler)
group.PUT("/users/me", authMiddleware.WithAdminNotRequired().Add(), uc.updateCurrentUserHandler)
group.DELETE("/users/:id", authMiddleware.Add(), uc.deleteUserHandler)
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
group.PUT("/users/:id/user-groups", authMiddleware.Add(), uc.updateUserGroups)
group.GET("/users/:id/profile-picture.png", uc.getUserProfilePictureHandler)
group.PUT("/users/:id/profile-picture", authMiddleware.Add(), uc.updateUserProfilePictureHandler)
group.PUT("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.updateCurrentUserProfilePictureHandler)
group.POST("/users/me/one-time-access-token", authMiddleware.WithAdminNotRequired().Add(), uc.createOwnOneTimeAccessTokenHandler)
group.POST("/users/:id/one-time-access-token", authMiddleware.Add(), uc.createAdminOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
group.POST("/one-time-access-email", rateLimitMiddleware.Add(rate.Every(10*time.Minute), 3), uc.requestOneTimeAccessEmailHandler)
group.DELETE("/users/:id/profile-picture", authMiddleware.Add(), uc.resetUserProfilePictureHandler)
group.DELETE("/users/me/profile-picture", authMiddleware.WithAdminNotRequired().Add(), uc.resetCurrentUserProfilePictureHandler)
}
type UserController struct {
@@ -41,183 +56,406 @@ type UserController struct {
appConfigService *service.AppConfigService
}
// getUserGroupsHandler godoc
// @Summary Get user groups
// @Description Retrieve all groups a specific user belongs to
// @Tags Users,User Groups
// @Param id path string true "User ID"
// @Success 200 {array} dto.UserGroupDtoWithUsers
// @Router /api/users/{id}/groups [get]
func (uc *UserController) getUserGroupsHandler(c *gin.Context) {
userID := c.Param("id")
groups, err := uc.userService.GetUserGroups(userID)
if err != nil {
_ = c.Error(err)
return
}
var groupsDto []dto.UserGroupDtoWithUsers
if err := dto.MapStructList(groups, &groupsDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, groupsDto)
}
// listUsersHandler godoc
// @Summary List users
// @Description Get a paginated list of users with optional search and sorting
// @Tags Users
// @Param search query string false "Search term to filter users"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("created_at")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("desc")
// @Success 200 {object} dto.Paginated[dto.UserDto]
// @Router /api/users [get]
func (uc *UserController) listUsersHandler(c *gin.Context) {
searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
users, pagination, err := uc.userService.ListUsers(searchTerm, sortedPaginationRequest)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var usersDto []dto.UserDto
if err := dto.MapStructList(users, &usersDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{
"data": usersDto,
"pagination": pagination,
c.JSON(http.StatusOK, dto.Paginated[dto.UserDto]{
Data: usersDto,
Pagination: pagination,
})
}
// getUserHandler godoc
// @Summary Get user by ID
// @Description Retrieve detailed information about a specific user
// @Tags Users
// @Param id path string true "User ID"
// @Success 200 {object} dto.UserDto
// @Router /api/users/{id} [get]
func (uc *UserController) getUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, userDto)
}
// getCurrentUserHandler godoc
// @Summary Get current user
// @Description Retrieve information about the currently authenticated user
// @Tags Users
// @Success 200 {object} dto.UserDto
// @Router /api/users/me [get]
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
user, err := uc.userService.GetUser(c.GetString("userID"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, userDto)
}
// deleteUserHandler godoc
// @Summary Delete user
// @Description Delete a specific user by ID
// @Tags Users
// @Param id path string true "User ID"
// @Success 204 "No Content"
// @Router /api/users/{id} [delete]
func (uc *UserController) deleteUserHandler(c *gin.Context) {
if err := uc.userService.DeleteUser(c.Param("id")); err != nil {
c.Error(err)
if err := uc.userService.DeleteUser(c.Param("id"), false); err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// createUserHandler godoc
// @Summary Create user
// @Description Create a new user
// @Tags Users
// @Param user body dto.UserCreateDto true "User information"
// @Success 201 {object} dto.UserDto
// @Router /api/users [post]
func (uc *UserController) createUserHandler(c *gin.Context) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
user, err := uc.userService.CreateUser(input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, userDto)
}
// updateUserHandler godoc
// @Summary Update user
// @Description Update an existing user by ID
// @Tags Users
// @Param id path string true "User ID"
// @Param user body dto.UserCreateDto true "User information"
// @Success 200 {object} dto.UserDto
// @Router /api/users/{id} [put]
func (uc *UserController) updateUserHandler(c *gin.Context) {
uc.updateUser(c, false)
}
// updateCurrentUserHandler godoc
// @Summary Update current user
// @Description Update the currently authenticated user's information
// @Tags Users
// @Param user body dto.UserCreateDto true "User information"
// @Success 200 {object} dto.UserDto
// @Router /api/users/me [put]
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
if uc.appConfigService.DbConfig.AllowOwnAccountEdit.Value != "true" {
c.Error(&common.AccountEditNotAllowedError{})
if !uc.appConfigService.DbConfig.AllowOwnAccountEdit.IsTrue() {
_ = c.Error(&common.AccountEditNotAllowedError{})
return
}
uc.updateUser(c, true)
}
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
var input dto.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
return
}
// getUserProfilePictureHandler godoc
// @Summary Get user profile picture
// @Description Retrieve a specific user's profile picture
// @Tags Users
// @Produce image/png
// @Param id path string true "User ID"
// @Success 200 {file} binary "PNG image"
// @Router /api/users/{id}/profile-picture.png [get]
func (uc *UserController) getUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id")
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
picture, size, err := uc.userService.GetProfilePicture(userID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, gin.H{"token": token})
c.Header("Cache-Control", "public, max-age=300")
c.DataFromReader(http.StatusOK, size, "image/png", picture, nil)
}
func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
// updateUserProfilePictureHandler godoc
// @Summary Update user profile picture
// @Description Update a specific user's profile picture
// @Tags Users
// @Accept multipart/form-data
// @Produce json
// @Param id path string true "User ID"
// @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)"
// @Success 204 "No Content"
// @Router /api/users/{id}/profile-picture [put]
func (uc *UserController) updateUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id")
fileHeader, err := c.FormFile("file")
if err != nil {
_ = c.Error(err)
return
}
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
file, err := fileHeader.Open()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// updateCurrentUserProfilePictureHandler godoc
// @Summary Update current user's profile picture
// @Description Update the currently authenticated user's profile picture
// @Tags Users
// @Accept multipart/form-data
// @Produce json
// @Param file formData file true "Profile picture image file (PNG, JPG, or JPEG)"
// @Success 204 "No Content"
// @Router /api/users/me/profile-picture [put]
func (uc *UserController) updateCurrentUserProfilePictureHandler(c *gin.Context) {
userID := c.GetString("userID")
fileHeader, err := c.FormFile("file")
if err != nil {
_ = c.Error(err)
return
}
file, err := fileHeader.Open()
if err != nil {
_ = c.Error(err)
return
}
defer file.Close()
if err := uc.userService.UpdateProfilePicture(userID, file); err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context, own bool) {
var input dto.OneTimeAccessTokenCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
if own {
input.UserID = c.GetString("userID")
}
token, err := uc.userService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
if err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, gin.H{"token": token})
}
// createOwnOneTimeAccessTokenHandler godoc
// @Summary Create one-time access token for current user
// @Description Generate a one-time access token for the currently authenticated user
// @Tags Users
// @Param id path string true "User ID"
// @Param body body dto.OneTimeAccessTokenCreateDto true "Token options"
// @Success 201 {object} object "{ \"token\": \"string\" }"
// @Router /api/users/{id}/one-time-access-token [post]
func (uc *UserController) createOwnOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, true)
}
func (uc *UserController) createAdminOneTimeAccessTokenHandler(c *gin.Context) {
uc.createOneTimeAccessTokenHandler(c, false)
}
func (uc *UserController) requestOneTimeAccessEmailHandler(c *gin.Context) {
var input dto.OneTimeAccessEmailDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
err := uc.userService.RequestOneTimeAccessEmail(input.Email, input.RedirectPath)
if err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// exchangeOneTimeAccessTokenHandler godoc
// @Summary Exchange one-time access token
// @Description Exchange a one-time access token for a session token
// @Tags Users
// @Param token path string true "One-time access token"
// @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/{token} [post]
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.ExchangeOneTimeAccessToken(c.Param("token"), c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value)
maxAge := sessionDurationInMinutesParsed * 60
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto)
}
// getSetupAccessTokenHandler godoc
// @Summary Setup initial admin
// @Description Generate setup access token for initial admin user configuration
// @Tags Users
// @Success 200 {object} dto.UserDto
// @Router /api/one-time-access-token/setup [post]
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
user, token, err := uc.userService.SetupInitialAdmin()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
sessionDurationInMinutesParsed, _ := strconv.Atoi(uc.appConfigService.DbConfig.SessionDuration.Value)
maxAge := sessionDurationInMinutesParsed * 60
maxAge := int(uc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto)
}
// updateUserGroups godoc
// @Summary Update user groups
// @Description Update the groups a specific user belongs to
// @Tags Users
// @Param id path string true "User ID"
// @Param groups body dto.UserUpdateUserGroupDto true "User group IDs"
// @Success 200 {object} dto.UserDto
// @Router /api/users/{id}/user-groups [put]
func (uc *UserController) updateUserGroups(c *gin.Context) {
var input dto.UserUpdateUserGroupDto
if err := c.ShouldBindJSON(&input); err != nil {
_ = c.Error(err)
return
}
user, err := uc.userService.UpdateUserGroups(c.Param("id"), input.UserGroupIds)
if err != nil {
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, userDto)
}
// updateUser is an internal helper method, not exposed as an API endpoint
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
var input dto.UserCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -230,15 +468,52 @@ func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
user, err := uc.userService.UpdateUser(userID, input, updateOwnUser, false)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, userDto)
}
// resetUserProfilePictureHandler godoc
// @Summary Reset user profile picture
// @Description Reset a specific user's profile picture to the default
// @Tags Users
// @Produce json
// @Param id path string true "User ID"
// @Success 204 "No Content"
// @Router /api/users/{id}/profile-picture [delete]
func (uc *UserController) resetUserProfilePictureHandler(c *gin.Context) {
userID := c.Param("id")
if err := uc.userService.ResetProfilePicture(userID); err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// resetCurrentUserProfilePictureHandler godoc
// @Summary Reset current user's profile picture
// @Description Reset the currently authenticated user's profile picture to the default
// @Tags Users
// @Produce json
// @Success 204 "No Content"
// @Router /api/users/me/profile-picture [delete]
func (uc *UserController) resetCurrentUserProfilePictureHandler(c *gin.Context) {
userID := c.GetString("userID")
if err := uc.userService.ResetProfilePicture(userID); err != nil {
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}

View File

@@ -10,144 +10,215 @@ import (
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
func NewUserGroupController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, userGroupService *service.UserGroupService) {
// NewUserGroupController creates a new controller for user group management
// @Summary User group management controller
// @Description Initializes all user group-related API endpoints
// @Tags User Groups
func NewUserGroupController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, userGroupService *service.UserGroupService) {
ugc := UserGroupController{
UserGroupService: userGroupService,
}
group.GET("/user-groups", jwtAuthMiddleware.Add(true), ugc.list)
group.GET("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.get)
group.POST("/user-groups", jwtAuthMiddleware.Add(true), ugc.create)
group.PUT("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.update)
group.DELETE("/user-groups/:id", jwtAuthMiddleware.Add(true), ugc.delete)
group.PUT("/user-groups/:id/users", jwtAuthMiddleware.Add(true), ugc.updateUsers)
userGroupsGroup := group.Group("/user-groups")
userGroupsGroup.Use(authMiddleware.Add())
{
userGroupsGroup.GET("", ugc.list)
userGroupsGroup.GET("/:id", ugc.get)
userGroupsGroup.POST("", ugc.create)
userGroupsGroup.PUT("/:id", ugc.update)
userGroupsGroup.DELETE("/:id", ugc.delete)
userGroupsGroup.PUT("/:id/users", ugc.updateUsers)
}
}
type UserGroupController struct {
UserGroupService *service.UserGroupService
}
// list godoc
// @Summary List user groups
// @Description Get a paginated list of user groups with optional search and sorting
// @Tags User Groups
// @Param search query string false "Search term to filter user groups by name"
// @Param page query int false "Page number, starting from 1" default(1)
// @Param limit query int false "Number of items per page" default(10)
// @Param sort_column query string false "Column to sort by" default("name")
// @Param sort_direction query string false "Sort direction (asc or desc)" default("asc")
// @Success 200 {object} dto.Paginated[dto.UserGroupDtoWithUserCount]
// @Router /api/user-groups [get]
func (ugc *UserGroupController) list(c *gin.Context) {
searchTerm := c.Query("search")
var sortedPaginationRequest utils.SortedPaginationRequest
if err := c.ShouldBindQuery(&sortedPaginationRequest); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
groups, pagination, err := ugc.UserGroupService.List(searchTerm, sortedPaginationRequest)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
// Map the user groups to DTOs. The user count can't be mapped directly, so we have to do it manually.
// Map the user groups to DTOs
var groupsDto = make([]dto.UserGroupDtoWithUserCount, len(groups))
for i, group := range groups {
var groupDto dto.UserGroupDtoWithUserCount
if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
groupDto.UserCount, err = ugc.UserGroupService.GetUserCountOfGroup(group.ID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
groupsDto[i] = groupDto
}
c.JSON(http.StatusOK, gin.H{
"data": groupsDto,
"pagination": pagination,
c.JSON(http.StatusOK, dto.Paginated[dto.UserGroupDtoWithUserCount]{
Data: groupsDto,
Pagination: pagination,
})
}
// get godoc
// @Summary Get user group by ID
// @Description Retrieve detailed information about a specific user group including its users
// @Tags User Groups
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth
// @Router /api/user-groups/{id} [get]
func (ugc *UserGroupController) get(c *gin.Context) {
group, err := ugc.UserGroupService.Get(c.Param("id"))
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, groupDto)
}
// create godoc
// @Summary Create user group
// @Description Create a new user group
// @Tags User Groups
// @Accept json
// @Produce json
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 201 {object} dto.UserGroupDtoWithUsers "Created user group"
// @Security BearerAuth
// @Router /api/user-groups [post]
func (ugc *UserGroupController) create(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
group, err := ugc.UserGroupService.Create(input)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusCreated, groupDto)
}
// update godoc
// @Summary Update user group
// @Description Update an existing user group by ID
// @Tags User Groups
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Param userGroup body dto.UserGroupCreateDto true "User group information"
// @Success 200 {object} dto.UserGroupDtoWithUsers "Updated user group"
// @Security BearerAuth
// @Router /api/user-groups/{id} [put]
func (ugc *UserGroupController) update(c *gin.Context) {
var input dto.UserGroupCreateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
group, err := ugc.UserGroupService.Update(c.Param("id"), input, false)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, groupDto)
}
// delete godoc
// @Summary Delete user group
// @Description Delete a specific user group by ID
// @Tags User Groups
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Success 204 "No Content"
// @Security BearerAuth
// @Router /api/user-groups/{id} [delete]
func (ugc *UserGroupController) delete(c *gin.Context) {
if err := ugc.UserGroupService.Delete(c.Param("id")); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.Status(http.StatusNoContent)
}
// updateUsers godoc
// @Summary Update users in a group
// @Description Update the list of users belonging to a specific user group
// @Tags User Groups
// @Accept json
// @Produce json
// @Param id path string true "User Group ID"
// @Param users body dto.UserGroupUpdateUsersDto true "List of user IDs to assign to this group"
// @Success 200 {object} dto.UserGroupDtoWithUsers
// @Security BearerAuth
// @Router /api/user-groups/{id}/users [put]
func (ugc *UserGroupController) updateUsers(c *gin.Context) {
var input dto.UserGroupUpdateUsersDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input)
group, err := ugc.UserGroupService.UpdateUsers(c.Param("id"), input.UserIDs)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var groupDto dto.UserGroupDtoWithUsers
if err := dto.MapStruct(group, &groupDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -2,7 +2,6 @@ package controller
import (
"net/http"
"strconv"
"time"
"github.com/go-webauthn/webauthn/protocol"
@@ -16,19 +15,19 @@ import (
"golang.org/x/time/rate"
)
func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, appConfigService *service.AppConfigService) {
func NewWebauthnController(group *gin.RouterGroup, authMiddleware *middleware.AuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, appConfigService *service.AppConfigService) {
wc := &WebauthnController{webAuthnService: webauthnService, appConfigService: appConfigService}
group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler)
group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler)
group.GET("/webauthn/register/start", authMiddleware.WithAdminNotRequired().Add(), wc.beginRegistrationHandler)
group.POST("/webauthn/register/finish", authMiddleware.WithAdminNotRequired().Add(), wc.verifyRegistrationHandler)
group.GET("/webauthn/login/start", wc.beginLoginHandler)
group.POST("/webauthn/login/finish", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), wc.verifyLoginHandler)
group.POST("/webauthn/logout", jwtAuthMiddleware.Add(false), wc.logoutHandler)
group.POST("/webauthn/logout", authMiddleware.WithAdminNotRequired().Add(), wc.logoutHandler)
group.GET("/webauthn/credentials", jwtAuthMiddleware.Add(false), wc.listCredentialsHandler)
group.PATCH("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.updateCredentialHandler)
group.DELETE("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.deleteCredentialHandler)
group.GET("/webauthn/credentials", authMiddleware.WithAdminNotRequired().Add(), wc.listCredentialsHandler)
group.PATCH("/webauthn/credentials/:id", authMiddleware.WithAdminNotRequired().Add(), wc.updateCredentialHandler)
group.DELETE("/webauthn/credentials/:id", authMiddleware.WithAdminNotRequired().Add(), wc.deleteCredentialHandler)
}
type WebauthnController struct {
@@ -40,7 +39,7 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
userID := c.GetString("userID")
options, err := wc.webAuthnService.BeginRegistration(userID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -51,20 +50,20 @@ func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
sessionID, err := c.Cookie(cookie.SessionIdCookieName)
if err != nil {
c.Error(&common.MissingSessionIdError{})
_ = c.Error(&common.MissingSessionIdError{})
return
}
userID := c.GetString("userID")
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -74,7 +73,7 @@ func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
options, err := wc.webAuthnService.BeginLogin()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -85,30 +84,29 @@ func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
sessionID, err := c.Cookie(cookie.SessionIdCookieName)
if err != nil {
c.Error(&common.MissingSessionIdError{})
_ = c.Error(&common.MissingSessionIdError{})
return
}
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
user, token, err := wc.webAuthnService.VerifyLogin(sessionID, credentialAssertionData, c.ClientIP(), c.Request.UserAgent())
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var userDto dto.UserDto
if err := dto.MapStruct(user, &userDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
sessionDurationInMinutesParsed, _ := strconv.Atoi(wc.appConfigService.DbConfig.SessionDuration.Value)
maxAge := sessionDurationInMinutesParsed * 60
maxAge := int(wc.appConfigService.DbConfig.SessionDuration.AsDurationMinutes().Seconds())
cookie.AddAccessTokenCookie(c, maxAge, token)
c.JSON(http.StatusOK, userDto)
@@ -118,13 +116,13 @@ func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
userID := c.GetString("userID")
credentials, err := wc.webAuthnService.ListCredentials(userID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var credentialDtos []dto.WebauthnCredentialDto
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -137,7 +135,7 @@ func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
@@ -150,19 +148,19 @@ func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
var input dto.WebauthnCredentialUpdateDto
if err := c.ShouldBindJSON(&input); err != nil {
c.Error(err)
_ = c.Error(err)
return
}
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
var credentialDto dto.WebauthnCredentialDto
if err := dto.MapStruct(credential, &credentialDto); err != nil {
c.Error(err)
_ = c.Error(err)
return
}

View File

@@ -1,47 +1,86 @@
package controller
import (
"encoding/json"
"fmt"
"log"
"net/http"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
// NewWellKnownController creates a new controller for OIDC discovery endpoints
// @Summary OIDC Discovery controller
// @Description Initializes OIDC discovery and JWKS endpoints
// @Tags Well Known
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
wkc := &WellKnownController{jwtService: jwtService}
// Pre-compute the OIDC configuration document, which is static
var err error
wkc.oidcConfig, err = wkc.computeOIDCConfiguration()
if err != nil {
log.Fatalf("Failed to pre-compute OpenID Connect configuration document: %v", err)
}
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
}
type WellKnownController struct {
jwtService *service.JwtService
oidcConfig []byte
}
// jwksHandler godoc
// @Summary Get JSON Web Key Set (JWKS)
// @Description Returns the JSON Web Key Set used for token verification
// @Tags Well Known
// @Produce json
// @Success 200 {object} object "{ \"keys\": []interface{} }"
// @Router /.well-known/jwks.json [get]
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
jwk, err := wkc.jwtService.GetJWK()
jwks, err := wkc.jwtService.GetPublicJWKSAsJSON()
if err != nil {
c.Error(err)
_ = c.Error(err)
return
}
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
c.Data(http.StatusOK, "application/json; charset=utf-8", jwks)
}
// openIDConfigurationHandler godoc
// @Summary Get OpenID Connect discovery configuration
// @Description Returns the OpenID Connect discovery document with endpoints and capabilities
// @Tags Well Known
// @Success 200 {object} object "OpenID Connect configuration"
// @Router /.well-known/openid-configuration [get]
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
c.Data(http.StatusOK, "application/json; charset=utf-8", wkc.oidcConfig)
}
func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
appUrl := common.EnvConfig.AppURL
config := map[string]interface{}{
alg, err := wkc.jwtService.GetKeyAlg()
if err != nil {
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
}
config := map[string]any{
"issuer": appUrl,
"authorization_endpoint": appUrl + "/authorize",
"token_endpoint": appUrl + "/api/oidc/token",
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
"end_session_endpoint": appUrl + "/api/oidc/end-session",
"jwks_uri": appUrl + "/.well-known/jwks.json",
"scopes_supported": []string{"openid", "profile", "email"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username"},
"grant_types_supported": []string{"authorization_code", "refresh_token"},
"scopes_supported": []string{"openid", "profile", "email", "groups"},
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
"response_types_supported": []string{"code", "id_token"},
"subject_types_supported": []string{"public"},
"id_token_signing_alg_values_supported": []string{"RS256"},
"id_token_signing_alg_values_supported": []string{alg.String()},
}
c.JSON(http.StatusOK, config)
return json.Marshal(config)
}

View File

@@ -0,0 +1,25 @@
package dto
import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type ApiKeyCreateDto struct {
Name string `json:"name" binding:"required,min=3,max=50"`
Description string `json:"description"`
ExpiresAt datatype.DateTime `json:"expiresAt" binding:"required"`
}
type ApiKeyDto struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
ExpiresAt datatype.DateTime `json:"expiresAt"`
LastUsedAt *datatype.DateTime `json:"lastUsedAt"`
CreatedAt datatype.DateTime `json:"createdAt"`
}
type ApiKeyResponseDto struct {
ApiKey ApiKeyDto `json:"apiKey"`
Token string `json:"token"`
}

View File

@@ -21,7 +21,7 @@ type AppConfigUpdateDto struct {
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
SmtpUser string `json:"smtpUser"`
SmtpPassword string `json:"smtpPassword"`
SmtpTls string `json:"smtpTls"`
SmtpTls string `json:"smtpTls" binding:"required,oneof=none starttls tls"`
SmtpSkipCertVerify string `json:"smtpSkipCertVerify"`
LdapEnabled string `json:"ldapEnabled" binding:"required"`
LdapUrl string `json:"ldapUrl"`
@@ -36,6 +36,7 @@ type AppConfigUpdateDto struct {
LdapAttributeUserEmail string `json:"ldapAttributeUserEmail"`
LdapAttributeUserFirstName string `json:"ldapAttributeUserFirstName"`
LdapAttributeUserLastName string `json:"ldapAttributeUserLastName"`
LdapAttributeUserProfilePicture string `json:"ldapAttributeUserProfilePicture"`
LdapAttributeGroupMember string `json:"ldapAttributeGroupMember"`
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
LdapAttributeGroupName string `json:"ldapAttributeGroupName"`

View File

@@ -6,6 +6,6 @@ type CustomClaimDto struct {
}
type CustomClaimCreateDto struct {
Key string `json:"key" binding:"required,claimKey"`
Key string `json:"key" binding:"required"`
Value string `json:"value" binding:"required"`
}

View File

@@ -40,13 +40,11 @@ func MapStruct[S any, D any](source S, destination *D) error {
}
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
// Loop through the fields of the destination struct
for i := 0; i < destVal.NumField(); i++ {
destField := destVal.Field(i)
destFieldType := destVal.Type().Field(i)
if destFieldType.Anonymous {
// Recursively handle embedded structs
if err := mapStructInternal(sourceVal, destField); err != nil {
return err
}
@@ -55,63 +53,57 @@ func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
sourceField := sourceVal.FieldByName(destFieldType.Name)
// If the source field is valid and can be assigned to the destination field
if sourceField.IsValid() && destField.CanSet() {
// Handle direct assignment for simple types
if sourceField.Type() == destField.Type() {
destField.Set(sourceField)
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
// Handle slices
if sourceField.Type().Elem() == destField.Type().Elem() {
// Direct assignment for slices of primitive types or non-struct elements
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
// Recursively map slices of structs
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
// Get the element from both source and destination slice
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
// Recursively map the struct elements
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
// Set the mapped element in the new slice
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
// Recursively map nested structs
if err := mapStructInternal(sourceField, destField); err != nil {
return err
}
} else {
// Type switch for specific type conversions
switch sourceField.Interface().(type) {
case datatype.DateTime:
// Convert datatype.DateTime to time.Time
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
dateValue := sourceField.Interface().(datatype.DateTime)
destField.Set(reflect.ValueOf(dateValue.ToTime()))
}
}
if err := mapField(sourceField, destField); err != nil {
return err
}
}
}
return nil
}
func mapField(sourceField reflect.Value, destField reflect.Value) error {
switch {
case sourceField.Type() == destField.Type():
destField.Set(sourceField)
case sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice:
return mapSlice(sourceField, destField)
case sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct:
return mapStructInternal(sourceField, destField)
default:
return mapSpecialTypes(sourceField, destField)
}
return nil
}
func mapSlice(sourceField reflect.Value, destField reflect.Value) error {
if sourceField.Type().Elem() == destField.Type().Elem() {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
newSlice.Index(j).Set(sourceField.Index(j))
}
destField.Set(newSlice)
} else if sourceField.Type().Elem().Kind() == reflect.Struct && destField.Type().Elem().Kind() == reflect.Struct {
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
for j := 0; j < sourceField.Len(); j++ {
sourceElem := sourceField.Index(j)
destElem := reflect.New(destField.Type().Elem()).Elem()
if err := mapStructInternal(sourceElem, destElem); err != nil {
return err
}
newSlice.Index(j).Set(destElem)
}
destField.Set(newSlice)
}
return nil
}
func mapSpecialTypes(sourceField reflect.Value, destField reflect.Value) error {
if _, ok := sourceField.Interface().(datatype.DateTime); ok {
if sourceField.Type() == reflect.TypeOf(datatype.DateTime{}) && destField.Type() == reflect.TypeOf(time.Time{}) {
dateValue := sourceField.Interface().(datatype.DateTime)
destField.Set(reflect.ValueOf(dateValue.ToTime()))
}
}
return nil
}

View File

@@ -1,13 +1,13 @@
package dto
type PublicOidcClientDto struct {
type OidcClientMetaDataDto struct {
ID string `json:"id"`
Name string `json:"name"`
HasLogo bool `json:"hasLogo"`
}
type OidcClientDto struct {
PublicOidcClientDto
OidcClientMetaDataDto
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
@@ -15,12 +15,8 @@ type OidcClientDto struct {
}
type OidcClientWithAllowedUserGroupsDto struct {
PublicOidcClientDto
CallbackURLs []string `json:"callbackURLs"`
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
IsPublic bool `json:"isPublic"`
PkceEnabled bool `json:"pkceEnabled"`
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
OidcClientDto
AllowedUserGroups []UserGroupDtoWithUserCount `json:"allowedUserGroups"`
}
type OidcClientCreateDto struct {
@@ -52,10 +48,11 @@ type AuthorizationRequiredDto struct {
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
Code string `form:"code"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
RefreshToken string `form:"refresh_token"`
}
type OidcUpdateAllowedUserGroupsDto struct {
@@ -68,3 +65,11 @@ type OidcLogoutDto struct {
PostLogoutRedirectUri string `form:"post_logout_redirect_uri"`
State string `form:"state"`
}
type OidcTokenResponseDto struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IdToken string `json:"id_token,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
}

View File

@@ -0,0 +1,10 @@
package dto
import "github.com/pocket-id/pocket-id/backend/internal/utils"
type Pagination = utils.PaginationResponse
type Paginated[T any] struct {
Data []T `json:"data"`
Pagination Pagination `json:"pagination"`
}

View File

@@ -9,21 +9,24 @@ type UserDto struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
CustomClaims []CustomClaimDto `json:"customClaims"`
UserGroups []UserGroupDto `json:"userGroups"`
LdapID *string `json:"ldapId"`
}
type UserCreateDto struct {
Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"required,min=1,max=50"`
IsAdmin bool `json:"isAdmin"`
LdapID string `json:"-"`
Username string `json:"username" binding:"required,username,min=2,max=50"`
Email string `json:"email" binding:"required,email"`
FirstName string `json:"firstName" binding:"required,min=1,max=50"`
LastName string `json:"lastName" binding:"required,min=1,max=50"`
IsAdmin bool `json:"isAdmin"`
Locale *string `json:"locale"`
LdapID string `json:"-"`
}
type OneTimeAccessTokenCreateDto struct {
UserID string `json:"userId" binding:"required"`
UserID string `json:"userId"`
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
}
@@ -31,3 +34,7 @@ type OneTimeAccessEmailDto struct {
Email string `json:"email" binding:"required,email"`
RedirectPath string `json:"redirectPath"`
}
type UserUpdateUserGroupDto struct {
UserGroupIds []string `json:"userGroupIds" binding:"required"`
}

View File

@@ -4,6 +4,15 @@ import (
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
)
type UserGroupDto struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`
Name string `json:"name"`
CustomClaims []CustomClaimDto `json:"customClaims"`
LdapID *string `json:"ldapId"`
CreatedAt datatype.DateTime `json:"createdAt"`
}
type UserGroupDtoWithUsers struct {
ID string `json:"id"`
FriendlyName string `json:"friendlyName"`

View File

@@ -1,10 +1,11 @@
package dto
import (
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
"log"
"regexp"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
@@ -16,22 +17,10 @@ var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
return matched
}
var validateClaimKey validator.Func = func(fl validator.FieldLevel) bool {
// The string can only contain letters and numbers
regex := "^[A-Za-z0-9]*$"
matched, _ := regexp.MatchString(regex, fl.Field().String())
return matched
}
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("username", validateUsername); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
if err := v.RegisterValidation("claimKey", validateClaimKey); err != nil {
log.Fatalf("Failed to register custom validation: %v", err)
}
}
}

View File

@@ -22,6 +22,8 @@ func RegisterDbCleanupJobs(db *gorm.DB) {
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
registerJob(scheduler, "ClearOidcRefreshTokens", "0 3 * * *", jobs.clearOidcRefreshTokens)
registerJob(scheduler, "ClearAuditLogs", "0 3 * * *", jobs.clearAuditLogs)
scheduler.Start()
}
@@ -44,6 +46,11 @@ func (j *Jobs) clearOidcAuthorizationCodes() error {
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearOidcAuthorizationCodes deletes OIDC authorization codes that have expired
func (j *Jobs) clearOidcRefreshTokens() error {
return j.db.Delete(&model.OidcRefreshToken{}, "expires_at < ?", datatype.DateTime(time.Now())).Error
}
// ClearAuditLogs deletes audit logs older than 90 days
func (j *Jobs) clearAuditLogs() error {
return j.db.Delete(&model.AuditLog{}, "created_at < ?", datatype.DateTime(time.Now().AddDate(0, 0, -90))).Error

View File

@@ -32,7 +32,7 @@ func RegisterLdapJobs(ldapService *service.LdapService, appConfigService *servic
}
func (j *LdapJobs) syncLdap() error {
if j.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if j.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return j.ldapService.SyncAll()
}
return nil

View File

@@ -0,0 +1,50 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
type ApiKeyAuthMiddleware struct {
apiKeyService *service.ApiKeyService
jwtService *service.JwtService
}
func NewApiKeyAuthMiddleware(apiKeyService *service.ApiKeyService, jwtService *service.JwtService) *ApiKeyAuthMiddleware {
return &ApiKeyAuthMiddleware{
apiKeyService: apiKeyService,
jwtService: jwtService,
}
}
func (m *ApiKeyAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
return func(c *gin.Context) {
userID, isAdmin, err := m.Verify(c, adminRequired)
if err != nil {
c.Abort()
_ = c.Error(err)
return
}
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Next()
}
}
func (m *ApiKeyAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (userID string, isAdmin bool, err error) {
apiKey := c.GetHeader("X-API-KEY")
user, err := m.apiKeyService.ValidateApiKey(apiKey)
if err != nil {
return "", false, &common.NotSignedInError{}
}
// Check if the user is an admin
if adminRequired && !user.IsAdmin {
return "", false, &common.MissingPermissionError{}
}
return user.ID, user.IsAdmin, nil
}

View File

@@ -0,0 +1,89 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/service"
)
// AuthMiddleware is a wrapper middleware that delegates to either API key or JWT authentication
type AuthMiddleware struct {
apiKeyMiddleware *ApiKeyAuthMiddleware
jwtMiddleware *JwtAuthMiddleware
options AuthOptions
}
type AuthOptions struct {
AdminRequired bool
SuccessOptional bool
}
func NewAuthMiddleware(
apiKeyService *service.ApiKeyService,
jwtService *service.JwtService,
) *AuthMiddleware {
return &AuthMiddleware{
apiKeyMiddleware: NewApiKeyAuthMiddleware(apiKeyService, jwtService),
jwtMiddleware: NewJwtAuthMiddleware(jwtService),
options: AuthOptions{
AdminRequired: true,
SuccessOptional: false,
},
}
}
// WithAdminNotRequired allows the middleware to continue with the request even if the user is not an admin
func (m *AuthMiddleware) WithAdminNotRequired() *AuthMiddleware {
// Create a new instance to avoid modifying the original
clone := &AuthMiddleware{
apiKeyMiddleware: m.apiKeyMiddleware,
jwtMiddleware: m.jwtMiddleware,
options: m.options,
}
clone.options.AdminRequired = false
return clone
}
// WithSuccessOptional allows the middleware to continue with the request even if authentication fails
func (m *AuthMiddleware) WithSuccessOptional() *AuthMiddleware {
// Create a new instance to avoid modifying the original
clone := &AuthMiddleware{
apiKeyMiddleware: m.apiKeyMiddleware,
jwtMiddleware: m.jwtMiddleware,
options: m.options,
}
clone.options.SuccessOptional = true
return clone
}
func (m *AuthMiddleware) Add() gin.HandlerFunc {
return func(c *gin.Context) {
// First try JWT auth
userID, isAdmin, err := m.jwtMiddleware.Verify(c, m.options.AdminRequired)
if err == nil {
// JWT auth succeeded, continue with the request
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Next()
return
}
// JWT auth failed, try API key auth
userID, isAdmin, err = m.apiKeyMiddleware.Verify(c, m.options.AdminRequired)
if err == nil {
// API key auth succeeded, continue with the request
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Next()
return
}
if m.options.SuccessOptional {
c.Next()
return
}
// Both JWT and API key auth failed
c.Abort()
_ = c.Error(err)
}
}

View File

@@ -1,6 +1,8 @@
package middleware
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/pocket-id/pocket-id/backend/internal/common"
)
@@ -23,7 +25,7 @@ func (m *CorsMiddleware) Add() gin.HandlerFunc {
c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")
if c.Request.Method == "OPTIONS" {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(204)
return
}

View File

@@ -19,7 +19,7 @@ func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
err = &common.FileTooLargeError{MaxSize: formatFileSize(maxSize)}
c.Error(err)
_ = c.Error(err)
c.Abort()
return
}

View File

@@ -10,51 +10,59 @@ import (
)
type JwtAuthMiddleware struct {
jwtService *service.JwtService
ignoreUnauthenticated bool
jwtService *service.JwtService
}
func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated}
func NewJwtAuthMiddleware(jwtService *service.JwtService) *JwtAuthMiddleware {
return &JwtAuthMiddleware{jwtService: jwtService}
}
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
func (m *JwtAuthMiddleware) Add(adminRequired bool) gin.HandlerFunc {
return func(c *gin.Context) {
// Extract the token from the cookie or the Authorization header
token, err := c.Cookie(cookie.AccessTokenCookieName)
userID, isAdmin, err := m.Verify(c, adminRequired)
if err != nil {
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
if len(authorizationHeaderSplitted) == 2 {
token = authorizationHeaderSplitted[1]
} else if m.ignoreUnauthenticated {
c.Next()
return
} else {
c.Error(&common.NotSignedInError{})
c.Abort()
return
}
}
claims, err := m.jwtService.VerifyAccessToken(token)
if err != nil && m.ignoreUnauthenticated {
c.Next()
return
} else if err != nil {
c.Error(&common.NotSignedInError{})
c.Abort()
_ = c.Error(err)
return
}
// Check if the user is an admin
if adminOnly && !claims.IsAdmin {
c.Error(&common.MissingPermissionError{})
c.Abort()
return
}
c.Set("userID", claims.Subject)
c.Set("userIsAdmin", claims.IsAdmin)
c.Set("userID", userID)
c.Set("userIsAdmin", isAdmin)
c.Next()
}
}
func (m *JwtAuthMiddleware) Verify(c *gin.Context, adminRequired bool) (subject string, isAdmin bool, err error) {
// Extract the token from the cookie
accessToken, err := c.Cookie(cookie.AccessTokenCookieName)
if err != nil {
// Try to extract the token from the Authorization header if it's not in the cookie
var ok bool
_, accessToken, ok = strings.Cut(c.GetHeader("Authorization"), " ")
if !ok || accessToken == "" {
return "", false, &common.NotSignedInError{}
}
}
token, err := m.jwtService.VerifyAccessToken(accessToken)
if err != nil {
return "", false, &common.NotSignedInError{}
}
subject, ok := token.Subject()
if !ok {
_ = c.Error(&common.TokenInvalidError{})
return
}
// Check if the user is an admin
isAdmin, err = service.GetIsAdmin(token)
if err != nil {
return "", false, &common.TokenInvalidError{}
}
if adminRequired && !isAdmin {
return "", false, &common.MissingPermissionError{}
}
return subject, isAdmin, nil
}

View File

@@ -36,7 +36,7 @@ func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
limiter := getLimiter(ip, limit, burst, &mu, clients)
if !limiter.Allow() {
c.Error(&common.TooManyRequestsError{})
_ = c.Error(&common.TooManyRequestsError{})
c.Abort()
return
}

View File

@@ -0,0 +1,16 @@
package model
import datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
type ApiKey struct {
Base
Name string `sortable:"true"`
Key string
Description *string
ExpiresAt datatype.DateTime `sortable:"true"`
LastUsedAt *datatype.DateTime `sortable:"true"`
UserID string
User User
}

View File

@@ -1,5 +1,10 @@
package model
import (
"strconv"
"time"
)
type AppConfigVariable struct {
Key string `gorm:"primaryKey;not null"`
Type string
@@ -9,6 +14,21 @@ type AppConfigVariable struct {
DefaultValue string
}
// IsTrue returns true if the value is a truthy string, such as "true", "t", "yes", "1", etc.
func (a *AppConfigVariable) IsTrue() bool {
ok, _ := strconv.ParseBool(a.Value)
return ok
}
// AsDurationMinutes returns the value as a time.Duration, interpreting the string as a whole number of minutes.
func (a *AppConfigVariable) AsDurationMinutes() time.Duration {
val, err := strconv.Atoi(a.Value)
if err != nil {
return 0
}
return time.Duration(val) * time.Minute
}
type AppConfig struct {
// General
AppName AppConfigVariable
@@ -43,6 +63,7 @@ type AppConfig struct {
LdapAttributeUserEmail AppConfigVariable
LdapAttributeUserFirstName AppConfigVariable
LdapAttributeUserLastName AppConfigVariable
LdapAttributeUserProfilePicture AppConfigVariable
LdapAttributeGroupMember AppConfigVariable
LdapAttributeGroupUniqueIdentifier AppConfigVariable
LdapAttributeGroupName AppConfigVariable

View File

@@ -0,0 +1,60 @@
package model
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAppConfigVariable_AsMinutesDuration(t *testing.T) {
tests := []struct {
name string
value string
expected time.Duration
expectedSeconds int
}{
{
name: "valid positive integer",
value: "60",
expected: 60 * time.Minute,
expectedSeconds: 3600,
},
{
name: "valid zero integer",
value: "0",
expected: 0,
expectedSeconds: 0,
},
{
name: "negative integer",
value: "-30",
expected: -30 * time.Minute,
expectedSeconds: -1800,
},
{
name: "invalid non-integer",
value: "not-a-number",
expected: 0,
expectedSeconds: 0,
},
{
name: "empty string",
value: "",
expected: 0,
expectedSeconds: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configVar := AppConfigVariable{
Value: tt.value,
}
result := configVar.AsDurationMinutes()
assert.Equal(t, tt.expected, result)
assert.Equal(t, tt.expectedSeconds, int(result.Seconds()))
})
}
}

View File

@@ -18,9 +18,9 @@ type AuditLog struct {
Data AuditLogData
}
type AuditLogData map[string]string
type AuditLogData map[string]string //nolint:recvcheck
type AuditLogEvent string
type AuditLogEvent string //nolint:recvcheck
const (
AuditLogEventSignIn AuditLogEvent = "SIGN_IN"

View File

@@ -4,20 +4,20 @@ import (
"time"
"github.com/google/uuid"
model "github.com/pocket-id/pocket-id/backend/internal/model/types"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"gorm.io/gorm"
)
// Base contains common columns for all tables.
type Base struct {
ID string `gorm:"primaryKey;not null"`
CreatedAt model.DateTime `sortable:"true"`
ID string `gorm:"primaryKey;not null"`
CreatedAt datatype.DateTime `sortable:"true"`
}
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
if b.ID == "" {
b.ID = uuid.New().String()
}
b.CreatedAt = model.DateTime(time.Now())
b.CreatedAt = datatype.DateTime(time.Now())
return
}

View File

@@ -51,13 +51,27 @@ type OidcClient struct {
CreatedBy User
}
type OidcRefreshToken struct {
Base
Token string
ExpiresAt datatype.DateTime
Scope string
UserID string
User User
ClientID string
Client OidcClient
}
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
// Compute HasLogo field
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
return nil
}
type UrlList []string
type UrlList []string //nolint:recvcheck
func (cu *UrlList) Scan(value interface{}) error {
if v, ok := value.([]byte); ok {

View File

@@ -8,7 +8,7 @@ import (
)
// DateTime custom type for time.Time to store date as unix timestamp for sqlite and as date for postgres
type DateTime time.Time
type DateTime time.Time //nolint:recvcheck
func (date *DateTime) Scan(value interface{}) (err error) {
*date = DateTime(value.(time.Time))

View File

@@ -14,6 +14,7 @@ type User struct {
FirstName string `sortable:"true"`
LastName string `sortable:"true"`
IsAdmin bool `sortable:"true"`
Locale *string
LdapID *string
CustomClaims []CustomClaim

View File

@@ -45,7 +45,7 @@ type PublicKeyCredentialRequestOptions struct {
Timeout time.Duration
}
type AuthenticatorTransportList []protocol.AuthenticatorTransport
type AuthenticatorTransportList []protocol.AuthenticatorTransport //nolint:recvcheck
// Scan and Value methods for GORM to handle the custom type
func (atl *AuthenticatorTransportList) Scan(value interface{}) error {

View File

@@ -0,0 +1,103 @@
package service
import (
"errors"
"log"
"time"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
)
type ApiKeyService struct {
db *gorm.DB
}
func NewApiKeyService(db *gorm.DB) *ApiKeyService {
return &ApiKeyService{db: db}
}
func (s *ApiKeyService) ListApiKeys(userID string, sortedPaginationRequest utils.SortedPaginationRequest) ([]model.ApiKey, utils.PaginationResponse, error) {
query := s.db.Where("user_id = ?", userID).Model(&model.ApiKey{})
var apiKeys []model.ApiKey
pagination, err := utils.PaginateAndSort(sortedPaginationRequest, query, &apiKeys)
if err != nil {
return nil, utils.PaginationResponse{}, err
}
return apiKeys, pagination, nil
}
func (s *ApiKeyService) CreateApiKey(userID string, input dto.ApiKeyCreateDto) (model.ApiKey, string, error) {
// Check if expiration is in the future
if !input.ExpiresAt.ToTime().After(time.Now()) {
return model.ApiKey{}, "", &common.APIKeyExpirationDateError{}
}
// Generate a secure random API key
token, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return model.ApiKey{}, "", err
}
apiKey := model.ApiKey{
Name: input.Name,
Key: utils.CreateSha256Hash(token), // Hash the token for storage
Description: &input.Description,
ExpiresAt: datatype.DateTime(input.ExpiresAt),
UserID: userID,
}
if err := s.db.Create(&apiKey).Error; err != nil {
return model.ApiKey{}, "", err
}
// Return the raw token only once - it cannot be retrieved later
return apiKey, token, nil
}
func (s *ApiKeyService) RevokeApiKey(userID, apiKeyID string) error {
var apiKey model.ApiKey
if err := s.db.Where("id = ? AND user_id = ?", apiKeyID, userID).First(&apiKey).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return &common.APIKeyNotFoundError{}
}
return err
}
return s.db.Delete(&apiKey).Error
}
func (s *ApiKeyService) ValidateApiKey(apiKey string) (model.User, error) {
if apiKey == "" {
return model.User{}, &common.NoAPIKeyProvidedError{}
}
var key model.ApiKey
hashedKey := utils.CreateSha256Hash(apiKey)
if err := s.db.Preload("User").Where("key = ? AND expires_at > ?",
hashedKey, datatype.DateTime(time.Now())).Preload("User").First(&key).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return model.User{}, &common.InvalidAPIKeyError{}
}
return model.User{}, err
}
// Update last used time
now := datatype.DateTime(time.Now())
key.LastUsedAt = &now
if err := s.db.Save(&key).Error; err != nil {
log.Printf("Failed to update last used time: %v", err)
}
return key.User, nil
}

View File

@@ -27,6 +27,7 @@ func NewAppConfigService(db *gorm.DB) *AppConfigService {
if err := service.InitDbConfig(); err != nil {
log.Fatalf("Failed to initialize app config service: %v", err)
}
return service
}
@@ -96,8 +97,8 @@ var defaultDbConfig = model.AppConfig{
},
SmtpTls: model.AppConfigVariable{
Key: "smtpTls",
Type: "bool",
DefaultValue: "true",
Type: "string",
DefaultValue: "none",
},
SmtpSkipCertVerify: model.AppConfigVariable{
Key: "smtpSkipCertVerify",
@@ -173,6 +174,10 @@ var defaultDbConfig = model.AppConfig{
Key: "ldapAttributeUserLastName",
Type: "string",
},
LdapAttributeUserProfilePicture: model.AppConfigVariable{
Key: "ldapAttributeUserProfilePicture",
Type: "string",
},
LdapAttributeGroupMember: model.AppConfigVariable{
Key: "ldapAttributeGroupMember",
Type: "string",

View File

@@ -60,7 +60,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ipAddress, userAgent, userID
}
// If the user hasn't logged in from the same device before and email notifications are enabled, send an email
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.Value == "true" && count <= 1 {
if s.appConfigService.DbConfig.EmailLoginNotificationEnabled.IsTrue() && count <= 1 {
go func() {
var user model.User
s.db.Where("id = ?", userID).First(&user)

View File

@@ -105,9 +105,10 @@ func (s *CustomClaimService) updateCustomClaims(idType idType, value string, cla
Value: claim.Value,
}
if idType == UserID {
switch idType {
case UserID:
customClaim.UserID = &value
} else if idType == UserGroupID {
case UserGroupID:
customClaim.UserGroupID = &value
}

View File

@@ -1,10 +1,11 @@
//go:build e2etest
package service
import (
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"log"
"os"
@@ -12,14 +13,15 @@ import (
"time"
"github.com/fxamacker/cbor/v2"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/resources"
"github.com/go-webauthn/webauthn/protocol"
"github.com/lestrrat-go/jwx/v3/jwk"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/utils"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/resources"
)
type TestService struct {
@@ -152,6 +154,17 @@ func (s *TestService) SeedDatabase() error {
return err
}
refreshToken := model.OidcRefreshToken{
Token: utils.CreateSha256Hash("ou87UDg249r1StBLYkMEqy9TXDbV5HmGuDpMcZDo"),
ExpiresAt: datatype.DateTime(time.Now().Add(24 * time.Hour)),
Scope: "openid profile email",
UserID: users[0].ID,
ClientID: oidcClients[0].ID,
}
if err := tx.Create(&refreshToken).Error; err != nil {
return err
}
accessToken := model.OneTimeAccessToken{
Token: "one-time-token",
ExpiresAt: datatype.DateTime(time.Now().Add(1 * time.Hour)),
@@ -212,6 +225,18 @@ func (s *TestService) SeedDatabase() error {
return err
}
apiKey := model.ApiKey{
Base: model.Base{
ID: "5f1fa856-c164-4295-961e-175a0d22d725",
},
Name: "Test API Key",
Key: "6c34966f57ef2bb7857649aff0e7ab3ad67af93c846342ced3f5a07be8706c20",
UserID: users[0].ID,
}
if err := tx.Create(&apiKey).Error; err != nil {
return err
}
return nil
})
}
@@ -292,40 +317,10 @@ func (s *TestService) ResetAppConfig() error {
}
func (s *TestService) SetJWTKeys() {
privateKeyString := `-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEAyaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B
83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC+585UXacoJ0c
hUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl/4EDDTO8HwawTjwkPo
QlRzeByhlvGPVvwgB3Fn93B8QJ/cZhXKxJvjjrC/8Pk76heC/ntEMru71Ix77BoC
3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeO
Zl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJwIDAQABAoIBAQCa8wNZJ08+9y6b
RzSIQcTaBuq1XY0oyYvCuX0ToruDyVNX3lJ48udb9vDIw9XsQans9CTeXXsjldGE
WPN7sapOcUg6ArMyJqc+zuO/YQu0EwYrTE48BOC7WIZvvTFnq9y+4R9HJjd0nTOv
iOlR1W5fAqbH2srgh1mfZ0UIp+9K6ymoinPXVGEXUAuuoMuTEZW/tnA2HT9WEllT
2FyMbmXrFzutAQqk9GRmnQh2OQZLxnQWyShVqJEhYBtm6JUUH1YJbyTVzMLgdBM8
ukgjTVtRDHaW51ubRSVdGBVT2m1RRtTsYAiZCpM5bwt88aSUS9yDOUiVH+irDg/3
IHEuL7IxAoGBAP2MpXPXtOwinajUQ9hKLDAtpq4axGvY+aGP5dNEMsuPo5ggOfUP
b4sqr73kaNFO3EbxQOQVoFjehhi4dQxt1/kAala9HZ5N7s26G2+eUWFF8jy7gWSN
qusNqGrG4g8D3WOyqZFb/x/m6SE0Jcg7zvIYbnAOq1Fexeik0Fc/DNzLAoGBAMua
d4XIfu4ydtU5AIaf1ZNXywgLg+LWxK8ELNqH/Y2vLAeIiTrOVp+hw9z+zHPD5cnu
6mix783PCOYNLTylrwtAz3fxSz14lsDFQM3ntzVF/6BniTTkKddctcPyqnTvamah
0hD2dzXBS/0mTBYIIMYTNbs0Yj87FTdJZw/+qa2VAoGBAKbzQkp54W6PCIMPabD0
fg4nMRZ5F5bv4seIKcunn068QPs9VQxQ4qCfNeLykDYqGA86cgD9YHzD4UZLxv6t
IUWbCWod0m/XXwPlpIUlmO5VEUD+MiAUzFNDxf6xAE7ku5UXImJNUjseX6l2Xd5v
yz9L6QQuFI5aujQKugiIwp5rAoGATtUVGCCkPNgfOLmkYXu7dxxUCV5kB01+xAEK
2OY0n0pG8vfDophH4/D/ZC7nvJ8J9uDhs/3JStexq1lIvaWtG99RNTChIEDzpdn6
GH9yaVcb/eB4uJjrNm64FhF8PGCCwxA+xMCZMaARKwhMB2/IOMkxUbWboL3gnhJ2
rDO/QO0CgYEA2Grt6uXHm61ji3xSdkBWNtUnj19vS1+7rFJp5SoYztVQVThf/W52
BAiXKBdYZDRVoItC/VS2NvAOjeJjhYO/xQ/q3hK7MdtuXfEPpLnyXKkmWo3lrJ26
wbeF6l05LexCkI7ShsOuSt+dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI=
-----END RSA PRIVATE KEY-----
`
const privateKeyString = `{"alg":"RS256","d":"mvMDWSdPPvcum0c0iEHE2gbqtV2NKMmLwrl9E6K7g8lTV95SePLnW_bwyMPV7EGp7PQk3l17I5XRhFjze7GqTnFIOgKzMianPs7jv2ELtBMGK0xOPATgu1iGb70xZ6vcvuEfRyY3dJ0zr4jpUdVuXwKmx9rK4IdZn2dFCKfvSuspqIpz11RhF1ALrqDLkxGVv7ZwNh0_VhJZU9hcjG5l6xc7rQEKpPRkZp0IdjkGS8Z0FskoVaiRIWAbZuiVFB9WCW8k1czC4HQTPLpII01bUQx2ludbm0UlXRgVU9ptUUbU7GAImQqTOW8LfPGklEvcgzlIlR_oqw4P9yBxLi-yMQ","dp":"pvNCSnnhbo8Igw9psPR-DicxFnkXlu_ix4gpy6efTrxA-z1VDFDioJ814vKQNioYDzpyAP1gfMPhRkvG_q0hRZsJah3Sb9dfA-WkhSWY7lURQP4yIBTMU0PF_rEATuS7lRciYk1SOx5fqXZd3m_LP0vpBC4Ujlq6NAq6CIjCnms","dq":"TtUVGCCkPNgfOLmkYXu7dxxUCV5kB01-xAEK2OY0n0pG8vfDophH4_D_ZC7nvJ8J9uDhs_3JStexq1lIvaWtG99RNTChIEDzpdn6GH9yaVcb_eB4uJjrNm64FhF8PGCCwxA-xMCZMaARKwhMB2_IOMkxUbWboL3gnhJ2rDO_QO0","e":"AQAB","kid":"8uHDw3M6rf8","kty":"RSA","n":"yaeEL0VKoPBXIAaWXsUgmu05lAvEIIdJn0FX9lHh4JE5UY9B83C5sCNdhs9iSWzpeP11EVjWp8i3Yv2CF7c7u50BXnVBGtxpZpFC-585UXacoJ0chUmarL9GRFJcM1nPHBTFu68aRrn1rIKNHUkNaaxFo0NFGl_4EDDTO8HwawTjwkPoQlRzeByhlvGPVvwgB3Fn93B8QJ_cZhXKxJvjjrC_8Pk76heC_ntEMru71Ix77BoC3j2TuyiN7m9RNBW8BU5q6lKoIdvIeZfTFLzi37iufyfvMrJTixp9zhNB1NxlLCeOZl2MXegtiGqd2H3cbAyqoOiv9ihUWTfXj7SxJw","p":"_Yylc9e07CKdqNRD2EosMC2mrhrEa9j5oY_l00Qyy4-jmCA59Q9viyqvveRo0U7cRvFA5BWgWN6GGLh1DG3X-QBqVr0dnk3uzbobb55RYUXyPLuBZI2q6w2oasbiDwPdY7KpkVv_H-bpITQlyDvO8hhucA6rUV7F6KTQVz8M3Ms","q":"y5p3hch-7jJ21TkAhp_Vk1fLCAuD4tbErwQs2of9ja8sB4iJOs5Wn6HD3P7Mc8Plye7qaLHvzc8I5g0tPKWvC0DPd_FLPXiWwMVAzee3NUX_oGeJNOQp11y1w_KqdO9qZqHSEPZ3NcFL_SZMFgggxhM1uzRiPzsVN0lnD_6prZU","qi":"2Grt6uXHm61ji3xSdkBWNtUnj19vS1-7rFJp5SoYztVQVThf_W52BAiXKBdYZDRVoItC_VS2NvAOjeJjhYO_xQ_q3hK7MdtuXfEPpLnyXKkmWo3lrJ26wbeF6l05LexCkI7ShsOuSt-dsyaTJTszuKDIA6YOfWvfo3aVZmlWRaI","use":"sig"}`
block, _ := pem.Decode([]byte(privateKeyString))
privateKey, _ := x509.ParsePKCS1PrivateKey(block.Bytes)
s.jwtService.PrivateKey = privateKey
s.jwtService.PublicKey = &privateKey.PublicKey
privateKey, _ := jwk.ParseKey([]byte(privateKeyString))
s.jwtService.SetKey(privateKey)
}
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key

View File

@@ -3,27 +3,26 @@ package service
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
htemplate "html/template"
"mime/multipart"
"mime/quotedprintable"
"net"
"net/smtp"
"net/textproto"
"os"
"strings"
ttemplate "text/template"
"time"
"github.com/emersion/go-sasl"
"github.com/emersion/go-smtp"
"github.com/google/uuid"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils/email"
"gorm.io/gorm"
)
var netDialer = &net.Dialer{
Timeout: 3 * time.Second,
}
type EmailService struct {
appConfigService *AppConfigService
db *gorm.DB
@@ -88,6 +87,29 @@ func SendEmail[V any](srv *EmailService, toEmail email.Address, template email.T
c.AddHeaderRaw("Content-Type",
fmt.Sprintf("multipart/alternative;\n boundary=%s;\n charset=UTF-8", boundary),
)
c.AddHeader("MIME-Version", "1.0")
c.AddHeader("Date", time.Now().Format(time.RFC1123Z))
// to create a message-id, we need the FQDN of the sending server, but that may be a docker hostname or localhost
// so we use the domain of the from address instead (the same as Thunderbird does)
// if the address does not have an @ (which would be unusual), we use hostname
from_address := srv.appConfigService.DbConfig.SmtpFrom.Value
domain := ""
if strings.Contains(from_address, "@") {
domain = strings.Split(from_address, "@")[1]
} else {
hostname, err := os.Hostname()
if err != nil {
// can that happen? we just give up
return fmt.Errorf("failed to get own hostname: %w", err)
} else {
domain = hostname
}
}
c.AddHeader("Message-ID", "<"+uuid.New().String()+"@"+domain+">")
c.Body(body)
// Connect to the SMTP server
@@ -110,109 +132,61 @@ func (srv *EmailService) getSmtpClient() (client *smtp.Client, err error) {
smtpAddress := srv.appConfigService.DbConfig.SmtpHost.Value + ":" + port
tlsConfig := &tls.Config{
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.Value == "true",
InsecureSkipVerify: srv.appConfigService.DbConfig.SmtpSkipCertVerify.IsTrue(), //nolint:gosec
ServerName: srv.appConfigService.DbConfig.SmtpHost.Value,
}
// Connect to the SMTP server
if srv.appConfigService.DbConfig.SmtpTls.Value == "false" {
client, err = srv.connectToSmtpServer(smtpAddress)
} else if port == "465" {
client, err = srv.connectToSmtpServerUsingImplicitTLS(
smtpAddress,
tlsConfig,
)
} else {
client, err = srv.connectToSmtpServerUsingStartTLS(
// Connect to the SMTP server based on TLS setting
switch srv.appConfigService.DbConfig.SmtpTls.Value {
case "none":
client, err = smtp.Dial(smtpAddress)
case "tls":
client, err = smtp.DialTLS(smtpAddress, tlsConfig)
case "starttls":
client, err = smtp.DialStartTLS(
smtpAddress,
tlsConfig,
)
default:
return nil, fmt.Errorf("invalid SMTP TLS setting: %s", srv.appConfigService.DbConfig.SmtpTls.Value)
}
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client.CommandTimeout = 10 * time.Second
// Send the HELO command
if err := srv.sendHelloCommand(client); err != nil {
return nil, fmt.Errorf("failed to send HELO command: %w", err)
}
// Set up the authentication if user or password are set
smtpUser := srv.appConfigService.DbConfig.SmtpUser.Value
smtpPassword := srv.appConfigService.DbConfig.SmtpPassword.Value
if smtpUser != "" || smtpPassword != "" {
auth := smtp.PlainAuth("",
srv.appConfigService.DbConfig.SmtpUser.Value,
srv.appConfigService.DbConfig.SmtpPassword.Value,
srv.appConfigService.DbConfig.SmtpHost.Value,
)
// Authenticate with plain auth
auth := sasl.NewPlainClient("", smtpUser, smtpPassword)
if err := client.Auth(auth); err != nil {
return nil, fmt.Errorf("failed to authenticate SMTP client: %w", err)
// If the server does not support plain auth, try login auth
var smtpErr *smtp.SMTPError
ok := errors.As(err, &smtpErr)
if ok && smtpErr.Code == smtp.ErrAuthUnknownMechanism.Code {
auth = sasl.NewLoginClient(smtpUser, smtpPassword)
err = client.Auth(auth)
}
// Both plain and login auth failed
if err != nil {
return nil, fmt.Errorf("failed to authenticate: %w", err)
}
}
}
return client, err
}
func (srv *EmailService) connectToSmtpServer(serverAddr string) (*smtp.Client, error) {
conn, err := netDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, srv.appConfigService.DbConfig.SmtpHost.Value)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := srv.sendHelloCommand(client); err != nil {
return nil, fmt.Errorf("failed to say hello to SMTP server: %w", err)
}
return client, err
}
func (srv *EmailService) connectToSmtpServerUsingImplicitTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
tlsDialer := &tls.Dialer{
NetDialer: netDialer,
Config: tlsConfig,
}
conn, err := tlsDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, srv.appConfigService.DbConfig.SmtpHost.Value)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := srv.sendHelloCommand(client); err != nil {
return nil, fmt.Errorf("failed to say hello to SMTP server: %w", err)
}
return client, nil
}
func (srv *EmailService) connectToSmtpServerUsingStartTLS(serverAddr string, tlsConfig *tls.Config) (*smtp.Client, error) {
conn, err := netDialer.Dial("tcp", serverAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, srv.appConfigService.DbConfig.SmtpHost.Value)
if err != nil {
conn.Close()
return nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := srv.sendHelloCommand(client); err != nil {
return nil, fmt.Errorf("failed to say hello to SMTP server: %w", err)
}
if err := client.StartTLS(tlsConfig); err != nil {
return nil, fmt.Errorf("failed to start TLS: %w", err)
}
return client, nil
}
func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
hostname, err := os.Hostname()
if err == nil {
@@ -224,23 +198,33 @@ func (srv *EmailService) sendHelloCommand(client *smtp.Client) error {
}
func (srv *EmailService) sendEmailContent(client *smtp.Client, toEmail email.Address, c *email.Composer) error {
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value); err != nil {
// Set the sender
if err := client.Mail(srv.appConfigService.DbConfig.SmtpFrom.Value, nil); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
if err := client.Rcpt(toEmail.Email); err != nil {
// Set the recipient
if err := client.Rcpt(toEmail.Email, nil); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
// Get a writer to write the email data
w, err := client.Data()
if err != nil {
return fmt.Errorf("failed to start data: %w", err)
}
// Write the email content
_, err = w.Write([]byte(c.String()))
if err != nil {
return fmt.Errorf("failed to write email data: %w", err)
}
// Close the writer
if err := w.Close(); err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
}

View File

@@ -31,7 +31,7 @@ var NewLoginTemplate = email.Template[NewLoginTemplateData]{
var OneTimeAccessTemplate = email.Template[OneTimeAccessTemplateData]{
Path: "one-time-access",
Title: func(data *email.TemplateData[OneTimeAccessTemplateData]) string {
return "One time access"
return "Login Code"
},
}
@@ -51,7 +51,9 @@ type NewLoginTemplateData struct {
}
type OneTimeAccessTemplateData = struct {
Link string
Code string
LoginLink string
LoginLinkWithCode string
}
// this is list of all template paths used for preloading templates

View File

@@ -3,6 +3,7 @@ package service
import (
"archive/tar"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
@@ -124,8 +125,15 @@ func (s *GeoLiteService) updateDatabase() error {
log.Println("Updating GeoLite2 City database...")
downloadUrl := fmt.Sprintf(common.EnvConfig.GeoLiteDBUrl, common.EnvConfig.MaxMindLicenseKey)
// Download the database tar.gz file
resp, err := http.Get(downloadUrl)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadUrl, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download database: %w", err)
}
@@ -164,6 +172,9 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tarReader := tar.NewReader(gzr)
var totalSize int64
const maxTotalSize = 300 * 1024 * 1024 // 300 MB limit for total decompressed size
// Iterate over the files in the tar archive
for {
header, err := tarReader.Next()
@@ -176,6 +187,11 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
// Check if the file is the GeoLite2-City.mmdb file
if header.Typeflag == tar.TypeReg && filepath.Base(header.Name) == "GeoLite2-City.mmdb" {
totalSize += header.Size
if totalSize > maxTotalSize {
return errors.New("total decompressed size exceeds maximum allowed limit")
}
// extract to a temporary file to avoid having a corrupted db in case of write failure.
baseDir := filepath.Dir(common.EnvConfig.GeoLiteDBPath)
tmpFile, err := os.CreateTemp(baseDir, "geolite.*.mmdb.tmp")
@@ -185,7 +201,7 @@ func (s *GeoLiteService) extractDatabase(reader io.Reader) error {
tempName := tmpFile.Name()
// Write the file contents directly to the target location
if _, err := io.Copy(tmpFile, tarReader); err != nil {
if _, err := io.Copy(tmpFile, tarReader); err != nil { //nolint:gosec
// if fails to write, then cleanup and throw an error
tmpFile.Close()
os.Remove(tempName)

View File

@@ -3,297 +3,495 @@ package service
import (
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"math/big"
"os"
"path/filepath"
"slices"
"strconv"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v3/jwa"
"github.com/lestrrat-go/jwx/v3/jwk"
"github.com/lestrrat-go/jwx/v3/jwt"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/pocket-id/pocket-id/backend/internal/utils"
)
const (
privateKeyPath = "data/keys/jwt_private_key.pem"
publicKeyPath = "data/keys/jwt_public_key.pem"
// PrivateKeyFile is the path in the data/keys folder where the key is stored
// This is a JSON file containing a key encoded as JWK
PrivateKeyFile = "jwt_private_key.json"
// RsaKeySize is the size, in bits, of the RSA key to generate if none is found
RsaKeySize = 2048
// KeyUsageSigning is the usage for the private keys, for the "use" property
KeyUsageSigning = "sig"
// IsAdminClaim is a boolean claim used in access tokens for admin users
// This may be omitted on non-admin tokens
IsAdminClaim = "isAdmin"
// Acceptable clock skew for verifying tokens
clockSkew = time.Minute
)
type JwtService struct {
PublicKey *rsa.PublicKey
PrivateKey *rsa.PrivateKey
privateKey jwk.Key
keyId string
appConfigService *AppConfigService
jwksEncoded []byte
}
func NewJwtService(appConfigService *AppConfigService) *JwtService {
service := &JwtService{
appConfigService: appConfigService,
}
service := &JwtService{}
// Ensure keys are generated or loaded
if err := service.loadOrGenerateKeys(); err != nil {
if err := service.init(appConfigService, common.EnvConfig.KeysPath); err != nil {
log.Fatalf("Failed to initialize jwt service: %v", err)
}
return service
}
type AccessTokenJWTClaims struct {
jwt.RegisteredClaims
IsAdmin bool `json:"isAdmin,omitempty"`
func (s *JwtService) init(appConfigService *AppConfigService, keysPath string) error {
s.appConfigService = appConfigService
// Ensure keys are generated or loaded
return s.loadOrGenerateKey(keysPath)
}
type JWK struct {
Kid string `json:"kid"`
Kty string `json:"kty"`
Use string `json:"use"`
Alg string `json:"alg"`
N string `json:"n"`
E string `json:"e"`
}
// loadOrGenerateKey loads the private key from the given path or generates it if not existing.
func (s *JwtService) loadOrGenerateKey(keysPath string) error {
var key jwk.Key
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
func (s *JwtService) loadOrGenerateKeys() error {
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
if err := s.generateKeys(); err != nil {
return err
// First, check if we have a JWK file
// If we do, then we just load that
jwkPath := filepath.Join(keysPath, PrivateKeyFile)
ok, err := utils.FileExists(jwkPath)
if err != nil {
return fmt.Errorf("failed to check if private key file (JWK) exists at path '%s': %w", jwkPath, err)
}
if ok {
key, err = s.loadKeyJWK(jwkPath)
if err != nil {
return fmt.Errorf("failed to load private key file (JWK) at path '%s': %w", jwkPath, err)
}
// Set the key, and we are done
err = s.SetKey(key)
if err != nil {
return fmt.Errorf("failed to set private key: %w", err)
}
return nil
}
privateKeyBytes, err := os.ReadFile(privateKeyPath)
// If we are here, we need to generate a new key
key, err = s.generateNewRSAKey()
if err != nil {
return errors.New("can't read jwt private key: " + err.Error())
}
s.PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
if err != nil {
return errors.New("can't parse jwt private key: " + err.Error())
return fmt.Errorf("failed to generate new private key: %w", err)
}
publicKeyBytes, err := os.ReadFile(publicKeyPath)
// Set the key in the object, which also validates it
err = s.SetKey(key)
if err != nil {
return errors.New("can't read jwt public key: " + err.Error())
return fmt.Errorf("failed to set private key: %w", err)
}
s.PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
// Save the key as JWK
err = SaveKeyJWK(s.privateKey, jwkPath)
if err != nil {
return errors.New("can't parse jwt public key: " + err.Error())
return fmt.Errorf("failed to save private key file at path '%s': %w", jwkPath, err)
}
return nil
}
func ValidateKey(privateKey jwk.Key) error {
// Validate the loaded key
err := privateKey.Validate()
if err != nil {
return fmt.Errorf("key object is invalid: %w", err)
}
keyID, ok := privateKey.KeyID()
if !ok || keyID == "" {
return errors.New("key object does not contain a key ID")
}
usage, ok := privateKey.KeyUsage()
if !ok || usage != KeyUsageSigning {
return errors.New("key object is not valid for signing")
}
ok, err = jwk.IsPrivateKey(privateKey)
if err != nil || !ok {
return errors.New("key object is not a private key")
}
return nil
}
func (s *JwtService) SetKey(privateKey jwk.Key) error {
// Validate the loaded key
err := ValidateKey(privateKey)
if err != nil {
return fmt.Errorf("private key is not valid: %w", err)
}
// Set the private key and key id in the object
s.privateKey = privateKey
keyId, ok := privateKey.KeyID()
if !ok {
return errors.New("key object does not contain a key ID")
}
s.keyId = keyId
// Create and encode a JWKS containing the public key
publicKey, err := s.GetPublicJWK()
if err != nil {
return fmt.Errorf("failed to get public JWK: %w", err)
}
jwks := jwk.NewSet()
err = jwks.AddKey(publicKey)
if err != nil {
return fmt.Errorf("failed to add public key to JWKS: %w", err)
}
s.jwksEncoded, err = json.Marshal(jwks)
if err != nil {
return fmt.Errorf("failed to encode JWKS to JSON: %w", err)
}
return nil
}
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
claim := AccessTokenJWTClaims{
RegisteredClaims: jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{common.EnvConfig.AppURL},
},
IsAdmin: user.IsAdmin,
}
kid, err := s.generateKeyID(s.PublicKey)
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
Expiration(now.Add(s.appConfigService.DbConfig.SessionDuration.AsDurationMinutes())).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
return "", fmt.Errorf("failed to build token: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
err = SetAudienceString(token, common.EnvConfig.AppURL)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
return token.SignedString(s.PrivateKey)
err = SetIsAdmin(token, user.IsAdmin)
if err != nil {
return "", fmt.Errorf("failed to set 'isAdmin' claim in token: %w", err)
}
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return string(signed), nil
}
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
func (s *JwtService) VerifyAccessToken(tokenString string) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithAudience(common.EnvConfig.AppURL),
jwt.WithIssuer(common.EnvConfig.AppURL),
)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
if !slices.Contains(claims.Audience, common.EnvConfig.AppURL) {
return nil, errors.New("audience doesn't match")
}
return claims, nil
return token, nil
}
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
claims := jwt.MapClaims{
"aud": clientID,
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
"iat": jwt.NewNumericDate(time.Now()),
"iss": common.EnvConfig.AppURL,
func (s *JwtService) GenerateIDToken(userClaims map[string]any, clientID string, nonce string) (string, error) {
now := time.Now()
token, err := jwt.NewBuilder().
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", fmt.Errorf("failed to build token: %w", err)
}
err = SetAudienceString(token, clientID)
if err != nil {
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
for k, v := range userClaims {
claims[k] = v
err = token.Set(k, v)
if err != nil {
return "", fmt.Errorf("failed to set claim '%s': %w", k, err)
}
}
if nonce != "" {
claims["nonce"] = nonce
err = token.Set("nonce", nonce)
if err != nil {
return "", fmt.Errorf("failed to set claim 'nonce': %w", err)
}
}
kid, err := s.generateKeyID(s.PublicKey)
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
return "", fmt.Errorf("failed to sign token: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = kid
return string(signed), nil
}
return token.SignedString(s.PrivateKey)
func (s *JwtService) VerifyIdToken(tokenString string, acceptExpiredTokens bool) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm()
opts := make([]jwt.ParseOption, 0)
// These options are always present
opts = append(opts,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
)
// By default, jwt.Parse includes 3 default validators for "nbf", "iat", and "exp"
// In case we want to accept expired tokens (during logout), we need to set the validators explicitly without validating "exp"
if acceptExpiredTokens {
// This is equivalent to the default validators except it doesn't validate "exp"
opts = append(opts,
jwt.WithResetValidators(true),
jwt.WithValidator(jwt.IsIssuedAtValid()),
jwt.WithValidator(jwt.IsNbfValid()),
)
}
token, err := jwt.ParseString(tokenString, opts...)
if err != nil {
return nil, fmt.Errorf("failed to parse token: %w", err)
}
return token, nil
}
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
claim := jwt.RegisteredClaims{
Subject: user.ID,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Audience: jwt.ClaimStrings{clientID},
Issuer: common.EnvConfig.AppURL,
}
kid, err := s.generateKeyID(s.PublicKey)
now := time.Now()
token, err := jwt.NewBuilder().
Subject(user.ID).
Expiration(now.Add(1 * time.Hour)).
IssuedAt(now).
Issuer(common.EnvConfig.AppURL).
Build()
if err != nil {
return "", errors.New("failed to generate key ID: " + err.Error())
return "", fmt.Errorf("failed to build token: %w", err)
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
token.Header["kid"] = kid
return token.SignedString(s.PrivateKey)
}
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
})
if err != nil || !token.Valid {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
func (s *JwtService) VerifyIdToken(tokenString string) (*jwt.RegisteredClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
return s.PublicKey, nil
}, jwt.WithIssuer(common.EnvConfig.AppURL))
if err != nil && !errors.Is(err, jwt.ErrTokenExpired) {
return nil, errors.New("couldn't handle this token")
}
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
if !isValid {
return nil, errors.New("can't parse claims")
}
return claims, nil
}
// GetJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetJWK() (JWK, error) {
if s.PublicKey == nil {
return JWK{}, errors.New("public key is not initialized")
}
kid, err := s.generateKeyID(s.PublicKey)
err = SetAudienceString(token, clientID)
if err != nil {
return JWK{}, err
return "", fmt.Errorf("failed to set 'aud' claim in token: %w", err)
}
jwk := JWK{
Kid: kid,
Kty: "RSA",
Use: "sig",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(s.PublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.PublicKey.E)).Bytes()),
alg, _ := s.privateKey.Algorithm()
signed, err := jwt.Sign(token, jwt.WithKey(alg, s.privateKey))
if err != nil {
return "", fmt.Errorf("failed to sign token: %w", err)
}
return jwk, nil
return string(signed), nil
}
// GenerateKeyID generates a Key ID for the public key using the first 8 bytes of the SHA-256 hash of the public key.
func (s *JwtService) generateKeyID(publicKey *rsa.PublicKey) (string, error) {
pubASN1, err := x509.MarshalPKIXPublicKey(publicKey)
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (jwt.Token, error) {
alg, _ := s.privateKey.Algorithm()
token, err := jwt.ParseString(
tokenString,
jwt.WithValidate(true),
jwt.WithKey(alg, s.privateKey),
jwt.WithAcceptableSkew(clockSkew),
jwt.WithIssuer(common.EnvConfig.AppURL),
)
if err != nil {
return "", errors.New("failed to marshal public key: " + err.Error())
return nil, fmt.Errorf("failed to parse token: %w", err)
}
// Compute SHA-256 hash of the public key
hash := sha256.New()
hash.Write(pubASN1)
hashed := hash.Sum(nil)
// Truncate the hash to the first 8 bytes for a shorter Key ID
shortHash := hashed[:8]
// Return Base64 encoded truncated hash as Key ID
return base64.RawURLEncoding.EncodeToString(shortHash), nil
return token, nil
}
// generateKeys generates a new RSA key pair and saves them to the specified paths.
func (s *JwtService) generateKeys() error {
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
return errors.New("failed to create directories for keys: " + err.Error())
// GetPublicJWK returns the JSON Web Key (JWK) for the public key.
func (s *JwtService) GetPublicJWK() (jwk.Key, error) {
if s.privateKey == nil {
return nil, errors.New("key is not initialized")
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
pubKey, err := s.privateKey.PublicKey()
if err != nil {
return errors.New("failed to generate private key: " + err.Error())
}
s.PrivateKey = privateKey
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
return err
return nil, fmt.Errorf("failed to get public key: %w", err)
}
publicKey := &privateKey.PublicKey
s.PublicKey = publicKey
EnsureAlgInKey(pubKey)
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
return err
}
return nil
return pubKey, nil
}
// savePEMKey saves a PEM encoded key to a file.
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
keyFile, err := os.Create(path)
// GetPublicJWKSAsJSON returns the JSON Web Key Set (JWKS) for the public key, encoded as JSON.
// The value is cached since the key is static.
func (s *JwtService) GetPublicJWKSAsJSON() ([]byte, error) {
if len(s.jwksEncoded) == 0 {
return nil, errors.New("key is not initialized")
}
return s.jwksEncoded, nil
}
// GetKeyAlg returns the algorithm of the key
func (s *JwtService) GetKeyAlg() (jwa.KeyAlgorithm, error) {
if len(s.jwksEncoded) == 0 {
return nil, errors.New("key is not initialized")
}
alg, ok := s.privateKey.Algorithm()
if !ok || alg == nil {
return nil, errors.New("failed to retrieve algorithm for key")
}
return alg, nil
}
func (s *JwtService) loadKeyJWK(path string) (jwk.Key, error) {
data, err := os.ReadFile(path)
if err != nil {
return errors.New("failed to create key file: " + err.Error())
return nil, fmt.Errorf("failed to read key data: %w", err)
}
key, err := jwk.ParseKey(data)
if err != nil {
return nil, fmt.Errorf("failed to parse key: %w", err)
}
return key, nil
}
// EnsureAlgInKey ensures that the key contains an "alg" parameter, set depending on the key type
func EnsureAlgInKey(key jwk.Key) {
_, ok := key.Algorithm()
if ok {
// Algorithm is already set
return
}
switch key.KeyType() {
case jwa.RSA():
// Default to RS256 for RSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.RS256())
case jwa.EC():
// Default to ES256 for ECDSA keys
_ = key.Set(jwk.AlgorithmKey, jwa.ES256())
case jwa.OKP():
// Default to EdDSA for OKP keys
_ = key.Set(jwk.AlgorithmKey, jwa.EdDSA())
}
}
func (s *JwtService) generateNewRSAKey() (jwk.Key, error) {
// We generate RSA keys only
rawKey, err := rsa.GenerateKey(rand.Reader, RsaKeySize)
if err != nil {
return nil, fmt.Errorf("failed to generate RSA private key: %w", err)
}
// Import the raw key
return importRawKey(rawKey)
}
func importRawKey(rawKey any) (jwk.Key, error) {
key, err := jwk.Import(rawKey)
if err != nil {
return nil, fmt.Errorf("failed to import generated private key: %w", err)
}
// Generate the key ID
kid, err := generateRandomKeyID()
if err != nil {
return nil, fmt.Errorf("failed to generate key ID: %w", err)
}
_ = key.Set(jwk.KeyIDKey, kid)
// Set other required fields
_ = key.Set(jwk.KeyUsageKey, KeyUsageSigning)
EnsureAlgInKey(key)
return key, err
}
// SaveKeyJWK saves a JWK to a file
func SaveKeyJWK(key jwk.Key, path string) error {
dir := filepath.Dir(path)
err := os.MkdirAll(dir, 0700)
if err != nil {
return fmt.Errorf("failed to create directory '%s' for key file: %w", dir, err)
}
keyFile, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return fmt.Errorf("failed to create key file: %w", err)
}
defer keyFile.Close()
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: keyType,
Bytes: keyBytes,
})
if _, err := keyFile.Write(keyPEM); err != nil {
return errors.New("failed to write key file: " + err.Error())
// Write the JSON file to disk
enc := json.NewEncoder(keyFile)
enc.SetEscapeHTML(false)
err = enc.Encode(key)
if err != nil {
return fmt.Errorf("failed to write key file: %w", err)
}
return nil
}
// generateRandomKeyID generates a random key ID.
func generateRandomKeyID() (string, error) {
buf := make([]byte, 8)
_, err := io.ReadFull(rand.Reader, buf)
if err != nil {
return "", fmt.Errorf("failed to read random bytes: %w", err)
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
// GetIsAdmin returns the value of the "isAdmin" claim in the token
func GetIsAdmin(token jwt.Token) (bool, error) {
if !token.Has(IsAdminClaim) {
return false, nil
}
var isAdmin bool
err := token.Get(IsAdminClaim, &isAdmin)
return isAdmin, err
}
// SetIsAdmin sets the "isAdmin" claim in the token
func SetIsAdmin(token jwt.Token, isAdmin bool) error {
// Only set if true
if !isAdmin {
return nil
}
return token.Set(IsAdminClaim, isAdmin)
}
// SetAudienceString sets the "aud" claim with a value that is a string, and not an array
// This is permitted by RFC 7519, and it's done here for backwards-compatibility
func SetAudienceString(token jwt.Token, audience string) error {
return token.Set(jwt.AudienceKey, audience)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,18 @@
package service
import (
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
"github.com/go-ldap/ldap/v3"
"github.com/pocket-id/pocket-id/backend/internal/dto"
@@ -24,13 +32,13 @@ func NewLdapService(db *gorm.DB, appConfigService *AppConfigService, userService
}
func (s *LdapService) createClient() (*ldap.Conn, error) {
if s.appConfigService.DbConfig.LdapEnabled.Value != "true" {
if !s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return nil, fmt.Errorf("LDAP is not enabled")
}
// Setup LDAP connection
ldapURL := s.appConfigService.DbConfig.LdapUrl.Value
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.Value == "true"
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify}))
skipTLSVerify := s.appConfigService.DbConfig.LdapSkipCertVerify.IsTrue()
client, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: skipTLSVerify})) //nolint:gosec
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP: %w", err)
}
@@ -59,6 +67,7 @@ func (s *LdapService) SyncAll() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncGroups() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -89,10 +98,16 @@ func (s *LdapService) SyncGroups() error {
ldapGroupIDs := make(map[string]bool)
for _, value := range result.Entries {
var usersToAddDto dto.UserGroupUpdateUsersDto
var membersUserId []string
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute)
// Skip groups without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP group without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute)
continue
}
ldapGroupIDs[ldapId] = true
// Try to find the group in the database
@@ -107,7 +122,16 @@ func (s *LdapService) SyncGroups() error {
singleMember := strings.Split(strings.Split(member, "=")[1], ",")[0]
var databaseUser model.User
s.db.Where("username = ?", singleMember).Where("ldap_id IS NOT NULL").First(&databaseUser)
err := s.db.Where("username = ? AND ldap_id IS NOT NULL", singleMember).First(&databaseUser).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// The user collides with a non-LDAP user, so we skip it
continue
} else {
return err
}
}
membersUserId = append(membersUserId, databaseUser.ID)
}
@@ -118,22 +142,21 @@ func (s *LdapService) SyncGroups() error {
LdapID: value.GetAttributeValue(uniqueIdentifierAttribute),
}
usersToAddDto = dto.UserGroupUpdateUsersDto{
UserIDs: membersUserId,
}
if databaseGroup.ID == "" {
newGroup, err := s.groupService.Create(syncGroup)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
} else {
if _, err = s.groupService.UpdateUsers(newGroup.ID, usersToAddDto); err != nil {
if _, err = s.groupService.UpdateUsers(newGroup.ID, membersUserId); err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
}
} else {
_, err = s.groupService.Update(databaseGroup.ID, syncGroup, true)
_, err = s.groupService.UpdateUsers(databaseGroup.ID, usersToAddDto)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
}
_, err = s.groupService.UpdateUsers(databaseGroup.ID, membersUserId)
if err != nil {
log.Printf("Error syncing group %s: %s", syncGroup.Name, err)
return err
@@ -146,7 +169,7 @@ func (s *LdapService) SyncGroups() error {
// Get all LDAP groups from the database
var ldapGroupsInDb []model.UserGroup
if err := s.db.Find(&ldapGroupsInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch groups from database: %v", err))
fmt.Println(fmt.Errorf("failed to fetch groups from database: %w", err))
}
// Delete groups that no longer exist in LDAP
@@ -163,6 +186,7 @@ func (s *LdapService) SyncGroups() error {
return nil
}
//nolint:gocognit
func (s *LdapService) SyncUsers() error {
// Setup LDAP connection
client, err := s.createClient()
@@ -177,6 +201,7 @@ func (s *LdapService) SyncUsers() error {
emailAttribute := s.appConfigService.DbConfig.LdapAttributeUserEmail.Value
firstNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserFirstName.Value
lastNameAttribute := s.appConfigService.DbConfig.LdapAttributeUserLastName.Value
profilePictureAttribute := s.appConfigService.DbConfig.LdapAttributeUserProfilePicture.Value
adminGroupAttribute := s.appConfigService.DbConfig.LdapAttributeAdminGroup.Value
filter := s.appConfigService.DbConfig.LdapUserSearchFilter.Value
@@ -189,6 +214,7 @@ func (s *LdapService) SyncUsers() error {
emailAttribute,
firstNameAttribute,
lastNameAttribute,
profilePictureAttribute,
}
// Filters must start and finish with ()!
@@ -204,6 +230,13 @@ func (s *LdapService) SyncUsers() error {
for _, value := range result.Entries {
ldapId := value.GetAttributeValue(uniqueIdentifierAttribute)
// Skip users without a valid LDAP ID
if ldapId == "" {
log.Printf("Skipping LDAP user without a valid unique identifier (attribute: %s)", uniqueIdentifierAttribute)
continue
}
ldapUserIDs[ldapId] = true
// Get the user from the database
@@ -237,21 +270,26 @@ func (s *LdapService) SyncUsers() error {
if err != nil {
log.Printf("Error syncing user %s: %s", newUser.Username, err)
}
}
// Save profile picture
if pictureString := value.GetAttributeValue(profilePictureAttribute); pictureString != "" {
if err := s.SaveProfilePicture(databaseUser.ID, pictureString); err != nil {
log.Printf("Error saving profile picture for user %s: %s", newUser.Username, err)
}
}
}
// Get all LDAP users from the database
var ldapUsersInDb []model.User
if err := s.db.Find(&ldapUsersInDb, "ldap_id IS NOT NULL").Select("ldap_id").Error; err != nil {
fmt.Println(fmt.Errorf("failed to fetch users from database: %v", err))
fmt.Println(fmt.Errorf("failed to fetch users from database: %w", err))
}
// Delete users that no longer exist in LDAP
for _, user := range ldapUsersInDb {
if _, exists := ldapUserIDs[*user.LdapID]; !exists {
if err := s.db.Delete(&model.User{}, "ldap_id = ?", user.LdapID).Error; err != nil {
if err := s.userService.DeleteUser(user.ID, true); err != nil {
log.Printf("Failed to delete user %s with: %v", user.Username, err)
} else {
log.Printf("Deleted user %s", user.Username)
@@ -260,3 +298,40 @@ func (s *LdapService) SyncUsers() error {
}
return nil
}
func (s *LdapService) SaveProfilePicture(userId string, pictureString string) error {
var reader io.Reader
if _, err := url.ParseRequestURI(pictureString); err == nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, pictureString, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
response, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download profile picture: %w", err)
}
defer response.Body.Close()
reader = response.Body
} else if decodedPhoto, err := base64.StdEncoding.DecodeString(pictureString); err == nil {
// If the photo is a base64 encoded string, decode it
reader = bytes.NewReader(decodedPhoto)
} else {
// If the photo is a string, we assume that it's a binary string
reader = bytes.NewReader([]byte(pictureString))
}
// Update the profile picture
if err := s.userService.UpdateProfilePicture(userId, reader); err != nil {
return fmt.Errorf("failed to update profile picture: %w", err)
}
return nil
}

View File

@@ -145,60 +145,136 @@ func (s *OidcService) IsUserGroupAllowedToAuthorize(user model.User, client mode
return isAllowedToAuthorize
}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier, refreshToken string) (idToken string, accessToken string, newRefreshToken string, exp int, err error) {
switch grantType {
case "authorization_code":
return s.createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier)
case "refresh_token":
accessToken, newRefreshToken, exp, err = s.createTokenFromRefreshToken(refreshToken, clientID, clientSecret)
return "", accessToken, newRefreshToken, exp, err
default:
return "", "", "", 0, &common.OidcGrantTypeNotSupportedError{}
}
}
func (s *OidcService) createTokenFromAuthorizationCode(code, clientID, clientSecret, codeVerifier string) (idToken string, accessToken string, refreshToken string, exp int, err error) {
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
return "", "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
return "", "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
return "", "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
var authorizationCodeMetaData model.OidcAuthorizationCode
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
// If the client is public or PKCE is enabled, the code verifier must match the code challenge
if client.IsPublic || client.PkceEnabled {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
return "", "", "", 0, &common.OidcInvalidCodeVerifierError{}
}
}
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
return "", "", "", 0, &common.OidcInvalidAuthorizationCodeError{}
}
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
if err != nil {
return "", "", err
return "", "", "", 0, err
}
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
idToken, err = s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
if err != nil {
return "", "", err
return "", "", "", 0, err
}
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
// Generate a refresh token
refreshToken, err = s.createRefreshToken(clientID, authorizationCodeMetaData.UserID, authorizationCodeMetaData.Scope)
if err != nil {
return "", "", "", 0, err
}
accessToken, err = s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
if err != nil {
return "", "", "", 0, err
}
s.db.Delete(&authorizationCodeMetaData)
return idToken, accessToken, nil
return idToken, accessToken, refreshToken, 3600, nil
}
func (s *OidcService) createTokenFromRefreshToken(refreshToken, clientID, clientSecret string) (accessToken string, newRefreshToken string, exp int, err error) {
if refreshToken == "" {
return "", "", 0, &common.OidcMissingRefreshTokenError{}
}
// Get the client to check if it's public
var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", 0, err
}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", 0, &common.OidcMissingClientCredentialsError{}
}
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", 0, &common.OidcClientSecretInvalidError{}
}
}
// Verify refresh token
var storedRefreshToken model.OidcRefreshToken
err = s.db.Preload("User").
Where("token = ? AND expires_at > ?", utils.CreateSha256Hash(refreshToken), datatype.DateTime(time.Now())).
First(&storedRefreshToken).
Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
return "", "", 0, err
}
// Verify that the refresh token belongs to the provided client
if storedRefreshToken.ClientID != clientID {
return "", "", 0, &common.OidcInvalidRefreshTokenError{}
}
// Generate a new access token
accessToken, err = s.jwtService.GenerateOauthAccessToken(storedRefreshToken.User, clientID)
if err != nil {
return "", "", 0, err
}
// Generate a new refresh token and invalidate the old one
newRefreshToken, err = s.createRefreshToken(clientID, storedRefreshToken.UserID, storedRefreshToken.Scope)
if err != nil {
return "", "", 0, err
}
// Delete the used refresh token
s.db.Delete(&storedRefreshToken)
return accessToken, newRefreshToken, 3600, nil
}
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
@@ -385,7 +461,7 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
if strings.Contains(scope, "email") {
claims["email"] = user.Email
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.Value == "true"
claims["email_verified"] = s.appConfigService.DbConfig.EmailsVerified.IsTrue()
}
if strings.Contains(scope, "groups") {
@@ -401,6 +477,7 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
"family_name": user.LastName,
"name": user.FullName(),
"preferred_username": user.Username,
"picture": fmt.Sprintf("%s/api/users/%s/profile-picture.png", common.EnvConfig.AppURL, user.ID),
}
if strings.Contains(scope, "profile") {
@@ -418,8 +495,8 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
for _, customClaim := range customClaims {
// The value of the custom claim can be a JSON object or a string
var jsonValue interface{}
json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if jsonValue != nil {
err := json.Unmarshal([]byte(customClaim.Value), &jsonValue)
if err == nil {
// It's JSON so we store it as an object
claims[customClaim.Key] = jsonValue
} else {
@@ -470,21 +547,24 @@ func (s *OidcService) ValidateEndSession(input dto.OidcLogoutDto, userID string)
}
// If the ID token hint is provided, verify the ID token
claims, err := s.jwtService.VerifyIdToken(input.IdTokenHint)
// Here we also accept expired ID tokens, which are fine per spec
token, err := s.jwtService.VerifyIdToken(input.IdTokenHint, true)
if err != nil {
return "", &common.TokenInvalidError{}
}
// If the client ID is provided check if the client ID in the ID token matches the client ID in the request
if input.ClientId != "" && claims.Audience[0] != input.ClientId {
clientID, ok := token.Audience()
if !ok || len(clientID) == 0 {
return "", &common.TokenInvalidError{}
}
if input.ClientId != "" && clientID[0] != input.ClientId {
return "", &common.OidcClientIdNotMatchingError{}
}
clientId := claims.Audience[0]
// Check if the user has authorized the client before
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientId, userID).Error; err != nil {
if err := s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", clientID[0], userID).Error; err != nil {
return "", &common.OidcMissingAuthorizationError{}
}
@@ -566,3 +646,28 @@ func (s *OidcService) getCallbackURL(urls []string, inputCallbackURL string) (ca
return "", &common.OidcInvalidCallbackURLError{}
}
func (s *OidcService) createRefreshToken(clientID string, userID string, scope string) (string, error) {
refreshToken, err := utils.GenerateRandomAlphanumericString(40)
if err != nil {
return "", err
}
// Compute the hash of the refresh token to store in the DB
// Refresh tokens are pretty long already, so a "simple" SHA-256 hash is enough
refreshTokenHash := utils.CreateSha256Hash(refreshToken)
m := model.OidcRefreshToken{
ExpiresAt: datatype.DateTime(time.Now().Add(30 * 24 * time.Hour)), // 30 days
Token: refreshTokenHash,
ClientID: clientID,
UserID: userID,
Scope: scope,
}
if err := s.db.Create(&m).Error; err != nil {
return "", err
}
return refreshToken, nil
}

View File

@@ -54,7 +54,7 @@ func (s *UserGroupService) Delete(id string) error {
}
// Disallow deleting the group if it is an LDAP group and LDAP is enabled
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return &common.LdapUserGroupUpdateError{}
}
@@ -87,7 +87,7 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
}
// Disallow updating the group if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if !allowLdapUpdate && group.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return model.UserGroup{}, &common.LdapUserGroupUpdateError{}
}
@@ -103,16 +103,16 @@ func (s *UserGroupService) Update(id string, input dto.UserGroupCreateDto, allow
return group, nil
}
func (s *UserGroupService) UpdateUsers(id string, input dto.UserGroupUpdateUsersDto) (group model.UserGroup, err error) {
func (s *UserGroupService) UpdateUsers(id string, userIds []string) (group model.UserGroup, err error) {
group, err = s.Get(id)
if err != nil {
return model.UserGroup{}, err
}
// Fetch the users based on UserIDs in input
// Fetch the users based on the userIds
var users []model.User
if len(input.UserIDs) > 0 {
if err := s.db.Where("id IN (?)", input.UserIDs).Find(&users).Error; err != nil {
if len(userIds) > 0 {
if err := s.db.Where("id IN (?)", userIds).Find(&users).Error; err != nil {
return model.UserGroup{}, err
}
}

View File

@@ -3,11 +3,16 @@ package service
import (
"errors"
"fmt"
"io"
"log"
"net/url"
"os"
"strings"
"time"
"github.com/google/uuid"
profilepicture "github.com/pocket-id/pocket-id/backend/internal/utils/image"
"github.com/pocket-id/pocket-id/backend/internal/common"
"github.com/pocket-id/pocket-id/backend/internal/dto"
"github.com/pocket-id/pocket-id/backend/internal/model"
@@ -44,21 +49,95 @@ func (s *UserService) ListUsers(searchTerm string, sortedPaginationRequest utils
func (s *UserService) GetUser(userID string) (model.User, error) {
var user model.User
err := s.db.Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
err := s.db.Preload("UserGroups").Preload("CustomClaims").Where("id = ?", userID).First(&user).Error
return user, err
}
func (s *UserService) DeleteUser(userID string) error {
func (s *UserService) GetProfilePicture(userID string) (io.Reader, int64, error) {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return nil, 0, &common.InvalidUUIDError{}
}
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
file, err := os.Open(profilePicturePath)
if err == nil {
// Get the file size
fileInfo, err := file.Stat()
if err != nil {
return nil, 0, err
}
return file, fileInfo.Size(), nil
}
// If the file does not exist, return the default profile picture
user, err := s.GetUser(userID)
if err != nil {
return nil, 0, err
}
defaultPicture, err := profilepicture.CreateDefaultProfilePicture(user.FirstName, user.LastName)
if err != nil {
return nil, 0, err
}
return defaultPicture, int64(defaultPicture.Len()), nil
}
func (s *UserService) GetUserGroups(userID string) ([]model.UserGroup, error) {
var user model.User
if err := s.db.Preload("UserGroups").Where("id = ?", userID).First(&user).Error; err != nil {
return nil, err
}
return user.UserGroups, nil
}
func (s *UserService) UpdateProfilePicture(userID string, file io.Reader) error {
// Validate the user ID to prevent directory traversal
err := uuid.Validate(userID)
if err != nil {
return &common.InvalidUUIDError{}
}
// Convert the image to a smaller square image
profilePicture, err := profilepicture.CreateProfilePicture(file)
if err != nil {
return err
}
// Ensure the directory exists
profilePictureDir := common.EnvConfig.UploadPath + "/profile-pictures"
err = os.MkdirAll(profilePictureDir, os.ModePerm)
if err != nil {
return err
}
// Create the profile picture file
err = utils.SaveFileStream(profilePicture, profilePictureDir+"/"+userID+".png")
if err != nil {
return err
}
return nil
}
func (s *UserService) DeleteUser(userID string, allowLdapDelete bool) error {
var user model.User
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
return err
}
// Disallow deleting the user if it is an LDAP user and LDAP is enabled
if user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if !allowLdapDelete && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return &common.LdapUserUpdateError{}
}
// Delete the profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
if err := os.Remove(profilePicturePath); err != nil && !os.IsNotExist(err) {
return err
}
return s.db.Delete(&user).Error
}
@@ -69,6 +148,7 @@ func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
Email: input.Email,
Username: input.Username,
IsAdmin: input.IsAdmin,
Locale: input.Locale,
}
if input.LdapID != "" {
user.LdapID = &input.LdapID
@@ -90,7 +170,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
}
// Disallow updating the user if it is an LDAP group and LDAP is enabled
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.Value == "true" {
if !allowLdapUpdate && user.LdapID != nil && s.appConfigService.DbConfig.LdapEnabled.IsTrue() {
return model.User{}, &common.LdapUserUpdateError{}
}
@@ -98,6 +178,7 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
user.LastName = updatedUser.LastName
user.Email = updatedUser.Email
user.Username = updatedUser.Username
user.Locale = updatedUser.Locale
if !updateOwnUser {
user.IsAdmin = updatedUser.IsAdmin
}
@@ -113,6 +194,11 @@ func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, u
}
func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath string) error {
isDisabled := !s.appConfigService.DbConfig.EmailOneTimeAccessEnabled.IsTrue()
if isDisabled {
return &common.OneTimeAccessDisabledError{}
}
var user model.User
if err := s.db.Where("email = ?", emailAddress).First(&user).Error; err != nil {
// Do not return error if user not found to prevent email enumeration
@@ -123,17 +209,18 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
}
}
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(time.Hour))
oneTimeAccessToken, err := s.CreateOneTimeAccessToken(user.ID, time.Now().Add(15*time.Minute))
if err != nil {
return err
}
link := fmt.Sprintf("%s/login/%s", common.EnvConfig.AppURL, oneTimeAccessToken)
link := fmt.Sprintf("%s/lc", common.EnvConfig.AppURL)
linkWithCode := fmt.Sprintf("%s/%s", link, oneTimeAccessToken)
// Add redirect path to the link
if strings.HasPrefix(redirectPath, "/") {
encodedRedirectPath := url.QueryEscape(redirectPath)
link = fmt.Sprintf("%s?redirect=%s", link, encodedRedirectPath)
linkWithCode = fmt.Sprintf("%s?redirect=%s", linkWithCode, encodedRedirectPath)
}
go func() {
@@ -141,7 +228,9 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
Name: user.Username,
Email: user.Email,
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
Link: link,
Code: oneTimeAccessToken,
LoginLink: link,
LoginLinkWithCode: linkWithCode,
})
if err != nil {
log.Printf("Failed to send email to '%s': %v\n", user.Email, err)
@@ -152,7 +241,14 @@ func (s *UserService) RequestOneTimeAccessEmail(emailAddress, redirectPath strin
}
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(16)
tokenLength := 16
// If expires at is less than 15 minutes, use an 6 character token instead of 16
if time.Until(expiresAt) <= 15*time.Minute {
tokenLength = 6
}
randomString, err := utils.GenerateRandomAlphanumericString(tokenLength)
if err != nil {
return "", err
}
@@ -194,6 +290,33 @@ func (s *UserService) ExchangeOneTimeAccessToken(token string, ipAddress, userAg
return oneTimeAccessToken.User, accessToken, nil
}
func (s *UserService) UpdateUserGroups(id string, userGroupIds []string) (user model.User, err error) {
user, err = s.GetUser(id)
if err != nil {
return model.User{}, err
}
// Fetch the groups based on userGroupIds
var groups []model.UserGroup
if len(userGroupIds) > 0 {
if err := s.db.Where("id IN (?)", userGroupIds).Find(&groups).Error; err != nil {
return model.User{}, err
}
}
// Replace the current groups with the new set of groups
if err := s.db.Model(&user).Association("UserGroups").Replace(groups); err != nil {
return model.User{}, err
}
// Save the updated user
if err := s.db.Save(&user).Error; err != nil {
return model.User{}, err
}
return user, nil
}
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
var userCount int64
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
@@ -239,3 +362,27 @@ func (s *UserService) checkDuplicatedFields(user model.User) error {
return nil
}
// ResetProfilePicture deletes a user's custom profile picture
func (s *UserService) ResetProfilePicture(userID string) error {
// Validate the user ID to prevent directory traversal
if err := uuid.Validate(userID); err != nil {
return &common.InvalidUUIDError{}
}
// Build path to profile picture
profilePicturePath := common.EnvConfig.UploadPath + "/profile-pictures/" + userID + ".png"
// Check if file exists and delete it
if _, err := os.Stat(profilePicturePath); err == nil {
if err := os.Remove(profilePicturePath); err != nil {
return fmt.Errorf("failed to delete profile picture: %w", err)
}
} else if !os.IsNotExist(err) {
// If any error other than "file not exists"
return fmt.Errorf("failed to check if profile picture exists: %w", err)
}
// It's okay if the file doesn't exist - just means there's no custom picture to delete
return nil
}

View File

@@ -95,8 +95,11 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
return model.WebauthnCredential{}, err
}
// Determine passkey name using AAGUID and User-Agent
passkeyName := s.determinePasskeyName(credential.Authenticator.AAGUID)
credentialToStore := model.WebauthnCredential{
Name: "New Passkey",
Name: passkeyName,
CredentialID: credential.ID,
AttestationType: credential.AttestationType,
PublicKey: credential.PublicKey,
@@ -112,6 +115,16 @@ func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.R
return credentialToStore, nil
}
func (s *WebAuthnService) determinePasskeyName(aaguid []byte) string {
// First try to identify by AAGUID using a combination of builtin + MDS
authenticatorName := utils.GetAuthenticatorName(aaguid)
if authenticatorName != "" {
return authenticatorName
}
return "New Passkey" // Default fallback
}
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
options, session, err := s.webAuthn.BeginDiscoverableLogin()
if err != nil {

View File

@@ -0,0 +1,64 @@
package utils
import (
"encoding/hex"
"encoding/json"
"fmt"
"log"
"sync"
"github.com/pocket-id/pocket-id/backend/resources"
)
var (
aaguidMap map[string]string
aaguidMapOnce sync.Once
)
// FormatAAGUID converts an AAGUID byte slice to UUID string format
func FormatAAGUID(aaguid []byte) string {
if len(aaguid) == 0 {
return ""
}
// If exactly 16 bytes, format as UUID
if len(aaguid) == 16 {
return fmt.Sprintf("%x-%x-%x-%x-%x",
aaguid[0:4], aaguid[4:6], aaguid[6:8], aaguid[8:10], aaguid[10:16])
}
// Otherwise just return as hex
return hex.EncodeToString(aaguid)
}
// GetAuthenticatorName returns the name of the authenticator for the given AAGUID
func GetAuthenticatorName(aaguid []byte) string {
aaguidStr := FormatAAGUID(aaguid)
if aaguidStr == "" {
return ""
}
// Then check JSON-sourced map
aaguidMapOnce.Do(loadAAGUIDsFromFile)
if name, ok := aaguidMap[aaguidStr]; ok {
return name + " Passkey"
}
return ""
}
// loadAAGUIDsFromFile loads AAGUID data from the embedded file system
func loadAAGUIDsFromFile() {
// Read from embedded file system
data, err := resources.FS.ReadFile("aaguids.json")
if err != nil {
log.Printf("Error reading embedded AAGUID file: %v", err)
return
}
if err := json.Unmarshal(data, &aaguidMap); err != nil {
log.Printf("Error unmarshalling AAGUID data: %v", err)
return
}
}

View File

@@ -0,0 +1,124 @@
package utils
import (
"encoding/hex"
"sync"
"testing"
)
func TestFormatAAGUID(t *testing.T) {
tests := []struct {
name string
aaguid []byte
want string
}{
{
name: "empty byte slice",
aaguid: []byte{},
want: "",
},
{
name: "16 byte slice - standard UUID",
aaguid: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10},
want: "01020304-0506-0708-090a-0b0c0d0e0f10",
},
{
name: "non-16 byte slice",
aaguid: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
want: "0102030405",
},
{
name: "specific UUID example",
aaguid: mustDecodeHex("adce000235bcc60a648b0b25f1f05503"),
want: "adce0002-35bc-c60a-648b-0b25f1f05503",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := FormatAAGUID(tt.aaguid)
if got != tt.want {
t.Errorf("FormatAAGUID() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetAuthenticatorName(t *testing.T) {
// Reset the aaguidMap for testing
originalMap := aaguidMap
defer func() {
aaguidMap = originalMap
}()
// Inject a test AAGUID map
aaguidMap = map[string]string{
"adce0002-35bc-c60a-648b-0b25f1f05503": "Test Authenticator",
"00000000-0000-0000-0000-000000000000": "Zero Authenticator",
}
aaguidMapOnce = sync.Once{}
aaguidMapOnce.Do(func() {}) // Mark as done to avoid loading from file
tests := []struct {
name string
aaguid []byte
want string
}{
{
name: "empty byte slice",
aaguid: []byte{},
want: "",
},
{
name: "known AAGUID",
aaguid: mustDecodeHex("adce000235bcc60a648b0b25f1f05503"),
want: "Test Authenticator Passkey",
},
{
name: "zero UUID",
aaguid: mustDecodeHex("00000000000000000000000000000000"),
want: "Zero Authenticator Passkey",
},
{
name: "unknown AAGUID",
aaguid: mustDecodeHex("ffffffffffffffffffffffffffffffff"),
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetAuthenticatorName(tt.aaguid)
if got != tt.want {
t.Errorf("GetAuthenticatorName() = %v, want %v", got, tt.want)
}
})
}
}
func TestLoadAAGUIDsFromFile(t *testing.T) {
// Reset the map and once flag for clean testing
aaguidMap = nil
aaguidMapOnce = sync.Once{}
// Trigger loading of AAGUIDs by calling GetAuthenticatorName
GetAuthenticatorName([]byte{0x01, 0x02, 0x03, 0x04})
if len(aaguidMap) == 0 {
t.Error("loadAAGUIDsFromFile() failed to populate aaguidMap")
}
// Check for a few known entries that should be in the embedded file
// This test will be more brittle as it depends on the content of aaguids.json,
// but it helps verify that the loading actually worked
t.Log("AAGUID map loaded with", len(aaguidMap), "entries")
}
// Helper function to convert hex string to bytes
func mustDecodeHex(s string) []byte {
bytes, err := hex.DecodeString(s)
if err != nil {
panic("invalid hex in test: " + err.Error())
}
return bytes
}

View File

@@ -45,7 +45,11 @@ func genAddressHeader(name string, addresses []Address, maxLength int) string {
} else {
email = fmt.Sprintf("<%s>", addr.Email)
}
writeHeaderQ(hl, addr.Name)
if isPrintableASCII(addr.Name) {
writeHeaderAtom(hl, addr.Name)
} else {
writeHeaderQ(hl, addr.Name)
}
writeHeaderAtom(hl, " ")
writeHeaderAtom(hl, email)
}
@@ -166,15 +170,13 @@ func (c *Composer) String() string {
func convertRunes(str string) []string {
var enc = make([]string, 0, len(str))
for _, r := range []rune(str) {
if r == ' ' {
for _, r := range str {
switch {
case r == ' ':
enc = append(enc, "_")
} else if isPrintableASCIIRune(r) &&
r != '=' &&
r != '?' &&
r != '_' {
case isPrintableASCIIRune(r) && r != '=' && r != '?' && r != '_':
enc = append(enc, string(r))
} else {
default:
enc = append(enc, string(toHex([]byte(string(r)))))
}
}
@@ -200,7 +202,7 @@ func hex(n byte) byte {
}
func isPrintableASCII(str string) bool {
for _, r := range []rune(str) {
for _, r := range str {
if !unicode.IsPrint(r) || r >= unicode.MaxASCII {
return false
}

View File

@@ -27,7 +27,7 @@ func GetTemplate[U any, V any](templateMap TemplateMap[U], template Template[V])
return templateMap[template.Path]
}
type clonable[V pareseable[V]] interface {
type cloneable[V pareseable[V]] interface {
Clone() (V, error)
}
@@ -35,7 +35,7 @@ type pareseable[V any] interface {
ParseFS(fs.FS, ...string) (V, error)
}
func prepareTemplate[V pareseable[V]](templateFS fs.FS, template string, rootTemplate clonable[V], suffix string) (V, error) {
func prepareTemplate[V pareseable[V]](templateFS fs.FS, template string, rootTemplate cloneable[V], suffix string) (V, error) {
tmpl, err := rootTemplate.Clone()
if err != nil {
return *new(V), fmt.Errorf("clone root template: %w", err)

View File

@@ -1,18 +1,25 @@
package utils
import (
"errors"
"fmt"
"hash/crc64"
"io"
"mime/multipart"
"os"
"path/filepath"
"strings"
"strconv"
"time"
"github.com/pocket-id/pocket-id/backend/resources"
)
func GetFileExtension(filename string) string {
splitted := strings.Split(filename, ".")
return splitted[len(splitted)-1]
ext := filepath.Ext(filename)
if len(ext) > 0 && ext[0] == '.' {
return ext[1:]
}
return filename
}
func GetImageMimeType(ext string) string {
@@ -67,12 +74,80 @@ func SaveFile(file *multipart.FileHeader, dst string) error {
return err
}
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, src)
return err
return SaveFileStream(src, dst)
}
// SaveFileStream saves a stream to a file.
func SaveFileStream(r io.Reader, dstFileName string) error {
// Our strategy is to save to a separate file and then rename it to override the original file
// First, get a temp file name that doesn't exist already
var tmpFileName string
var i int64
for {
seed := strconv.FormatInt(time.Now().UnixNano()+i, 10)
suffix := crc64.Checksum([]byte(dstFileName+seed), crc64.MakeTable(crc64.ISO))
tmpFileName = dstFileName + "." + strconv.FormatUint(suffix, 10)
exists, err := FileExists(tmpFileName)
if err != nil {
return fmt.Errorf("failed to check if file '%s' exists: %w", tmpFileName, err)
}
if !exists {
break
}
i++
}
// Write to the temporary file
tmpFile, err := os.Create(tmpFileName)
if err != nil {
return fmt.Errorf("failed to open file '%s' for writing: %w", tmpFileName, err)
}
n, err := io.Copy(tmpFile, r)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to write to file '%s': %w", tmpFileName, err)
}
err = tmpFile.Close()
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to close stream to file '%s': %w", tmpFileName, err)
}
if n == 0 {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return errors.New("no data written")
}
// Rename to the final file, which overrides existing files
// This is an atomic operation
err = os.Rename(tmpFileName, dstFileName)
if err != nil {
// Delete the temporary file; we ignore errors here
_ = os.Remove(tmpFileName)
return fmt.Errorf("failed to rename file '%s': %w", dstFileName, err)
}
return nil
}
// FileExists returns true if a file exists on disk and is a regular file
func FileExists(path string) (bool, error) {
s, err := os.Stat(path)
if err != nil {
if os.IsNotExist(err) {
err = nil
}
return false, err
}
return !s.IsDir(), nil
}

View File

@@ -0,0 +1,73 @@
package utils
import (
"testing"
)
func TestGetFileExtension(t *testing.T) {
tests := []struct {
name string
filename string
want string
}{
{
name: "Simple file with extension",
filename: "document.pdf",
want: "pdf",
},
{
name: "File with path",
filename: "/path/to/document.txt",
want: "txt",
},
{
name: "File with path (Windows style)",
filename: "C:\\path\\to\\document.jpg",
want: "jpg",
},
{
name: "Multiple extensions",
filename: "archive.tar.gz",
want: "gz",
},
{
name: "Hidden file with extension",
filename: ".config.json",
want: "json",
},
{
name: "Filename with dots",
filename: "version.1.2.3.txt",
want: "txt",
},
{
name: "File with uppercase extension",
filename: "image.JPG",
want: "JPG",
},
{
name: "File without extension",
filename: "README",
want: "README",
},
{
name: "Hidden file without extension",
filename: ".gitignore",
want: "gitignore",
},
{
name: "Empty filename",
filename: "",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GetFileExtension(tt.filename)
if got != tt.want {
t.Errorf("GetFileExtension(%q) = %q, want %q", tt.filename, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,11 @@
package utils
import (
"crypto/sha256"
"encoding/hex"
)
func CreateSha256Hash(input string) string {
hash := sha256.Sum256([]byte(input))
return hex.EncodeToString(hash[:])
}

View File

@@ -0,0 +1,103 @@
package profilepicture
import (
"bytes"
"fmt"
"image"
"image/color"
"io"
"strings"
"github.com/disintegration/imageorient"
"github.com/disintegration/imaging"
"golang.org/x/image/font"
"golang.org/x/image/font/opentype"
"golang.org/x/image/math/fixed"
"github.com/pocket-id/pocket-id/backend/resources"
)
const profilePictureSize = 300
// CreateProfilePicture resizes the profile picture to a square
func CreateProfilePicture(file io.Reader) (io.Reader, error) {
img, _, err := imageorient.Decode(file)
if err != nil {
return nil, fmt.Errorf("failed to decode image: %w", err)
}
img = imaging.Fill(img, profilePictureSize, profilePictureSize, imaging.Center, imaging.Lanczos)
pr, pw := io.Pipe()
go func() {
err = imaging.Encode(pw, img, imaging.PNG)
if err != nil {
_ = pw.CloseWithError(fmt.Errorf("failed to encode image: %w", err))
return
}
pw.Close()
}()
return pr, nil
}
// CreateDefaultProfilePicture creates a profile picture with the initials
func CreateDefaultProfilePicture(firstName, lastName string) (*bytes.Buffer, error) {
// Get the initials
initials := ""
if len(firstName) > 0 {
initials += string(firstName[0])
}
if len(lastName) > 0 {
initials += string(lastName[0])
}
initials = strings.ToUpper(initials)
// Create a blank image with a white background
img := imaging.New(profilePictureSize, profilePictureSize, color.RGBA{R: 255, G: 255, B: 255, A: 255})
// Load the font
fontBytes, err := resources.FS.ReadFile("fonts/PlayfairDisplay-Bold.ttf")
if err != nil {
return nil, fmt.Errorf("failed to read font file: %w", err)
}
// Parse the font
fontFace, err := opentype.Parse(fontBytes)
if err != nil {
return nil, fmt.Errorf("failed to parse font: %w", err)
}
// Create a font.Face with a specific size
fontSize := 160.0
face, err := opentype.NewFace(fontFace, &opentype.FaceOptions{
Size: fontSize,
DPI: 72,
})
if err != nil {
return nil, fmt.Errorf("failed to create font face: %w", err)
}
// Create a drawer for the image
drawer := &font.Drawer{
Dst: img,
Src: image.NewUniform(color.RGBA{R: 0, G: 0, B: 0, A: 255}), // Black text color
Face: face,
}
// Center the initials
x := (profilePictureSize - font.MeasureString(face, initials).Ceil()) / 2
y := (profilePictureSize-face.Metrics().Height.Ceil())/2 + face.Metrics().Ascent.Ceil() - 10
drawer.Dot = fixed.P(x, y)
// Draw the initials
drawer.DrawString(initials)
var buf bytes.Buffer
err = imaging.Encode(&buf, img, imaging.PNG)
if err != nil {
return nil, fmt.Errorf("failed to encode image: %w", err)
}
return &buf, nil
}

View File

@@ -1,8 +1,10 @@
package utils
import (
"gorm.io/gorm"
"reflect"
"strconv"
"gorm.io/gorm"
)
type PaginationResponse struct {
@@ -30,7 +32,7 @@ func PaginateAndSort(sortedPaginationRequest SortedPaginationRequest, query *gor
capitalizedSortColumn := CapitalizeFirstLetter(sort.Column)
sortField, sortFieldFound := reflect.TypeOf(result).Elem().Elem().FieldByName(capitalizedSortColumn)
isSortable := sortField.Tag.Get("sortable") == "true"
isSortable, _ := strconv.ParseBool(sortField.Tag.Get("sortable"))
isValidSortOrder := sort.Direction == "asc" || sort.Direction == "desc"
if sortFieldFound && isSortable && isValidSortOrder {
@@ -47,7 +49,7 @@ func Paginate(page int, pageSize int, query *gorm.DB, result interface{}) (Pagin
}
if pageSize < 1 {
pageSize = 10
pageSize = 20
} else if pageSize > 100 {
pageSize = 100
}

View File

@@ -2,8 +2,9 @@ package utils
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"io"
"net/url"
"regexp"
"strings"
@@ -13,23 +14,41 @@ import (
// GenerateRandomAlphanumericString generates a random alphanumeric string of the given length
func GenerateRandomAlphanumericString(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
const charsetLength = int64(len(charset))
if length <= 0 {
return "", fmt.Errorf("length must be a positive integer")
return "", errors.New("length must be a positive integer")
}
result := make([]byte, length)
// The algorithm below is adapted from https://stackoverflow.com/a/35615565
const (
letterIdxBits = 6 // 6 bits to represent a letter index
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(charsetLength))
if err != nil {
return "", err
result := strings.Builder{}
result.Grow(length)
// Because we discard a bunch of bytes, we read more in the buffer to minimize the changes of performing additional IO
bufferSize := int(float64(length) * 1.3)
randomBytes := make([]byte, bufferSize)
for i, j := 0, 0; i < length; j++ {
// Fill the buffer if needed
if j%bufferSize == 0 {
_, err := io.ReadFull(rand.Reader, randomBytes)
if err != nil {
return "", fmt.Errorf("failed to generate random bytes: %w", err)
}
}
// Discard bytes that are outside of the range
// This allows making sure that we maintain uniform distribution
idx := int(randomBytes[j%length] & letterIdxMask)
if idx < len(charset) {
result.WriteByte(charset[idx])
i++
}
result[i] = charset[num.Int64()]
}
return string(result), nil
return result.String(), nil
}
func GetHostnameFromURL(rawURL string) string {
@@ -45,30 +64,40 @@ func StringPointer(s string) *string {
return &s
}
func CapitalizeFirstLetter(s string) string {
if s == "" {
return s
func CapitalizeFirstLetter(str string) string {
if str == "" {
return ""
}
runes := []rune(s)
runes[0] = unicode.ToUpper(runes[0])
return string(runes)
result := strings.Builder{}
result.Grow(len(str))
for i, r := range str {
if i == 0 {
result.WriteRune(unicode.ToUpper(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func CamelCaseToSnakeCase(s string) string {
var result []rune
for i, r := range s {
func CamelCaseToSnakeCase(str string) string {
result := strings.Builder{}
result.Grow(int(float32(len(str)) * 1.1))
for i, r := range str {
if unicode.IsUpper(r) && i > 0 {
result = append(result, '_')
result.WriteByte('_')
}
result = append(result, unicode.ToLower(r))
result.WriteRune(unicode.ToLower(r))
}
return string(result)
return result.String()
}
var camelCaseToScreamingSnakeCaseRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
func CamelCaseToScreamingSnakeCase(s string) string {
// Insert underscores before uppercase letters (except the first one)
re := regexp.MustCompile(`([a-z0-9])([A-Z])`)
snake := re.ReplaceAllString(s, `${1}_${2}`)
snake := camelCaseToScreamingSnakeCaseRe.ReplaceAllString(s, `${1}_${2}`)
// Convert to uppercase
return strings.ToUpper(snake)

View File

@@ -0,0 +1,105 @@
package utils
import (
"regexp"
"testing"
)
func TestGenerateRandomAlphanumericString(t *testing.T) {
t.Run("valid length returns correct string", func(t *testing.T) {
const length = 10
str, err := GenerateRandomAlphanumericString(length)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(str) != length {
t.Errorf("Expected length %d, got %d", length, len(str))
}
matched, err := regexp.MatchString(`^[a-zA-Z0-9]+$`, str)
if err != nil {
t.Errorf("Regex match failed: %v", err)
}
if !matched {
t.Errorf("String contains non-alphanumeric characters: %s", str)
}
})
t.Run("zero length returns error", func(t *testing.T) {
_, err := GenerateRandomAlphanumericString(0)
if err == nil {
t.Error("Expected error for zero length, got nil")
}
})
t.Run("negative length returns error", func(t *testing.T) {
_, err := GenerateRandomAlphanumericString(-1)
if err == nil {
t.Error("Expected error for negative length, got nil")
}
})
t.Run("generates different strings", func(t *testing.T) {
str1, _ := GenerateRandomAlphanumericString(10)
str2, _ := GenerateRandomAlphanumericString(10)
if str1 == str2 {
t.Error("Generated strings should be different")
}
})
}
func TestCapitalizeFirstLetter(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", ""},
{"lowercase first letter", "hello", "Hello"},
{"already capitalized", "Hello", "Hello"},
{"single lowercase letter", "h", "H"},
{"single uppercase letter", "H", "H"},
{"starts with number", "123abc", "123abc"},
{"unicode character", "étoile", "Étoile"},
{"special character", "_test", "_test"},
{"multi-word", "hello world", "Hello world"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CapitalizeFirstLetter(tt.input)
if result != tt.expected {
t.Errorf("CapitalizeFirstLetter(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestCamelCaseToSnakeCase(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"empty string", "", ""},
{"simple camelCase", "camelCase", "camel_case"},
{"PascalCase", "PascalCase", "pascal_case"},
{"multipleWordsInCamelCase", "multipleWordsInCamelCase", "multiple_words_in_camel_case"},
{"consecutive uppercase", "HTTPRequest", "h_t_t_p_request"},
{"single lowercase word", "word", "word"},
{"single uppercase word", "WORD", "w_o_r_d"},
{"with numbers", "camel123Case", "camel123_case"},
{"with numbers in middle", "model2Name", "model2_name"},
{"mixed case", "iPhone6sPlus", "i_phone6s_plus"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := CamelCaseToSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("CamelCaseToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,33 @@
package systemd
import (
"net"
"os"
)
// SdNotifyReady sends a message to the systemd daemon to notify that service is ready to operate.
// It is common to ignore the error.
func SdNotifyReady() error {
socketAddr := &net.UnixAddr{
Name: os.Getenv("NOTIFY_SOCKET"),
Net: "unixgram",
}
if socketAddr.Name == "" {
return nil
}
conn, err := net.DialUnix(socketAddr.Net, nil, socketAddr)
if err != nil {
return err
}
defer func() {
_ = conn.Close()
}()
if _, err = conn.Write([]byte("READY=1")); err != nil {
return err
}
return nil
}

BIN
backend/main Executable file

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@@ -1,95 +1,92 @@
{{ define "style" }}
<style>
body {
font-family: Arial, sans-serif;
background-color: #f0f0f0;
color: #333;
margin: 0;
padding: 0;
}
.container {
background-color: #fff;
color: #333;
padding: 32px;
border-radius: 10px;
max-width: 600px;
margin: 40px auto;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 24px;
}
.header .logo {
display: flex;
align-items: center;
gap: 8px;
}
.header .logo img {
width: 32px;
height: 32px;
object-fit: cover;
}
.header h1 {
font-size: 1.5rem;
font-weight: bold;
}
.warning {
background-color: #ffd966;
color: #7f6000;
padding: 4px 12px;
border-radius: 50px;
font-size: 0.875rem;
}
.content {
background-color: #fafafa;
color: #333;
padding: 24px;
border-radius: 10px;
}
.content h2 {
font-size: 1.25rem;
font-weight: bold;
margin-bottom: 16px;
}
.grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 16px;
margin-bottom: 16px;
}
.grid div {
display: flex;
flex-direction: column;
}
.grid p {
margin: 0;
}
.label {
color: #888;
font-size: 0.875rem;
margin-bottom: 4px;
}
.message {
font-size: 1rem;
line-height: 1.5;
}
.button {
border-radius: 0.375rem;
font-size: 1rem;
font-weight: 500;
background-color: #000000;
color: #ffffff;
padding: 0.7rem 1.5rem;
outline: none;
border: none;
text-decoration: none;
}
.button-container {
text-align: center;
margin-top: 24px;
}
/* Reset styles for email clients */
body, table, td, p, a {
margin: 0;
padding: 0;
border: 0;
font-size: 100%;
font-family: Arial, sans-serif;
line-height: 1.5;
}
body {
background-color: #f0f0f0;
color: #333;
}
.container {
width: 100%;
max-width: 600px;
margin: 40px auto;
background-color: #fff;
border-radius: 10px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
padding: 32px;
}
.header {
display: flex;
margin-bottom: 24px;
}
.header .logo img {
width: 32px;
height: 32px;
vertical-align: middle;
}
.header h1 {
font-size: 1.5rem;
font-weight: bold;
display: inline-block;
vertical-align: middle;
margin-left: 8px;
}
.warning {
background-color: #ffd966;
color: #7f6000;
padding: 4px 12px;
border-radius: 50px;
font-size: 0.875rem;
margin: auto 0 auto auto;
}
.content {
background-color: #fafafa;
padding: 24px;
border-radius: 10px;
}
.content h2 {
font-size: 1.25rem;
font-weight: bold;
margin-bottom: 16px;
}
.grid {
width: 100%;
margin-bottom: 16px;
}
.grid td {
width: 50%;
padding-bottom: 8px;
vertical-align: top;
}
.label {
color: #888;
font-size: 0.875rem;
}
.message {
font-size: 1rem;
line-height: 1.5;
margin-top: 16px;
}
.button {
background-color: #000000;
color: #ffffff;
padding: 0.7rem 1.5rem;
text-decoration: none;
border-radius: 4px;
font-size: 1rem;
font-weight: 500;
display: inline-block;
margin-top: 24px;
}
.button-container {
text-align: center;
}
</style>
{{ end }}

View File

@@ -1,36 +1,40 @@
{{ define "base" }}
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="{{ .AppName }}"/>
<h1>{{ .AppName }}</h1>
</div>
<div class="warning">Warning</div>
</div>
<div class="content">
<h2>New Sign-In Detected</h2>
<div class="grid">
{{ if and .Data.City .Data.Country }}
<div>
<p class="label">Approximate Location</p>
<p>{{ .Data.City }}, {{ .Data.Country }}</p>
</div>
{{ end }}
<div>
<p class="label">IP Address</p>
<p>{{ .Data.IPAddress }}</p>
</div>
<div>
<p class="label">Device</p>
<p>{{ .Data.Device }}</p>
</div>
<div>
<p class="label">Sign-In Time</p>
<p>{{ .Data.DateTime.Format "2006-01-02 15:04:05 UTC" }}</p>
</div>
</div>
<p class="message">
This sign-in was detected from a new device or location. If you recognize this activity, you can
safely ignore this message. If not, please review your account and security settings.
</p>
</div>
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
<h1>{{ .AppName }}</h1>
</div>
<div class="warning">Warning</div>
</div>
<div class="content">
<h2>New Sign-In Detected</h2>
<table class="grid">
<tr>
{{ if and .Data.City .Data.Country }}
<td>
<p class="label">Approximate Location</p>
<p>{{ .Data.City }}, {{ .Data.Country }}</p>
</td>
{{ end }}
<td>
<p class="label">IP Address</p>
<p>{{ .Data.IPAddress }}</p>
</td>
</tr>
<tr>
<td>
<p class="label">Device</p>
<p>{{ .Data.Device }}</p>
</td>
<td>
<p class="label">Sign-In Time</p>
<p>{{ .Data.DateTime.Format "2006-01-02 15:04:05 UTC" }}</p>
</td>
</tr>
</table>
<p class="message">
This sign-in was detected from a new device or location. If you recognize this activity, you can
safely ignore this message. If not, please review your account and security settings.
</p>
</div>
{{ end -}}

View File

@@ -1,17 +1,17 @@
{{ define "base" }}
<div class="header">
<div class="logo">
<img src="{{ .LogoURL }}" alt="{{ .AppName }}"/>
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
<h1>{{ .AppName }}</h1>
</div>
</div>
<div class="content">
<h2>One-Time Access</h2>
<h2>Login Code</h2>
<p class="message">
Click the button below to sign in to {{ .AppName }} with a one-time access link. This link expires in 15 minutes.
Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in 15 minutes.
</p>
<div class="button-container">
<a class="button" href="{{ .Data.Link }}" class="button">Sign In</a>
<a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a>
</div>
</div>
{{ end -}}

Some files were not shown because too many files have changed in this diff Show More