mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-08 01:10:19 +03:00
Compare commits
21 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3fe24a04de | ||
|
|
6769cc8c10 | ||
|
|
97f7fc4e28 | ||
|
|
fc47c2a2a4 | ||
|
|
f1a6c8db85 | ||
|
|
552d7ccfa5 | ||
|
|
e45b0b3ed0 | ||
|
|
8166e2ead7 | ||
|
|
ae7aeb0945 | ||
|
|
16f273ffce | ||
|
|
9f49e5577e | ||
|
|
d92b80b80f | ||
|
|
aaed71e1c8 | ||
|
|
4780548843 | ||
|
|
a5dfdd2178 | ||
|
|
9eec7a3e9e | ||
|
|
fdc1921f5d | ||
|
|
601f6c488a | ||
|
|
0595d73ea5 | ||
|
|
74f4c22800 | ||
|
|
b49063d692 |
35
CHANGELOG.md
35
CHANGELOG.md
@@ -1,3 +1,38 @@
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.3.0...v) (2024-08-24)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* empty lists don't get returned correctly from the api ([97f7fc4](https://github.com/stonith404/pocket-id/commit/97f7fc4e288c2bb49210072a7a151b58ef44f5b5))
|
||||
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.2.1...v) (2024-08-23)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add support for multiple callback urls ([8166e2e](https://github.com/stonith404/pocket-id/commit/8166e2ead7fc71a0b7a45950b05c5c65a60833b6))
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* db migration for multiple callback urls ([552d7cc](https://github.com/stonith404/pocket-id/commit/552d7ccfa58d7922ecb94bdfe6a86651b4cf2745))
|
||||
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.2.0...v) (2024-08-19)
|
||||
|
||||
|
||||
### Bug Fixes
|
||||
|
||||
* session duration can't be updated ([4780548](https://github.com/stonith404/pocket-id/commit/478054884389ed8a08d707fd82da7b31177a67e5))
|
||||
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.1.3...v) (2024-08-19)
|
||||
|
||||
|
||||
### Features
|
||||
|
||||
* add `INTERNAL_BACKEND_URL` env variable ([0595d73](https://github.com/stonith404/pocket-id/commit/0595d73ea5afbd7937b8f292ffe624139f818f41))
|
||||
* add user info endpoint to support more oidc clients ([fdc1921](https://github.com/stonith404/pocket-id/commit/fdc1921f5dcb5ac6beef8d1c9b1b7c53f514cce5))
|
||||
* change default logo ([9eec7a3](https://github.com/stonith404/pocket-id/commit/9eec7a3e9eb7f690099f38a5d4cf7c2516ea9ef9))
|
||||
|
||||
## [](https://github.com/stonith404/pocket-id/compare/v0.1.2...v) (2024-08-13)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ RUN npm run build
|
||||
RUN npm prune --production
|
||||
|
||||
# Stage 2: Build Backend
|
||||
FROM golang:1.22-alpine AS backend-builder
|
||||
FROM golang:1.23-alpine AS backend-builder
|
||||
WORKDIR /app/backend
|
||||
COPY ./backend/go.mod ./backend/go.sum ./
|
||||
RUN go mod download
|
||||
|
||||
81
README.md
81
README.md
@@ -1,16 +1,22 @@
|
||||
# <div align="center"><img src="https://github.com/user-attachments/assets/5b5e0d42-e2b4-4523-add5-ac87042a72f1" width="100"/> </br>Pocket ID</div>
|
||||
# <div align="center"><img src="https://github.com/user-attachments/assets/4ceb2708-9f29-4694-b797-be833efce17d" width="100"/> </br>Pocket ID</div>
|
||||
|
||||
Pocket ID is a simple OIDC provider that allows users to authenticate with their passkeys to your services.
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/953c534c-b667-44e5-b976-a59142f1efb8" width="1200"/>
|
||||
<img src="https://github.com/user-attachments/assets/96ac549d-b897-404a-8811-f42b16ea58e2" width="1200"/>
|
||||
|
||||
The goal of Pocket ID is to be a simple and easy-to-use. There are other self-hosted OIDC providers like [Keycloak](https://www.keycloak.org/) or [ORY Hydra](https://www.ory.sh/hydra/) but they are often too complex for simple use cases. Additionally, Pocket ID only support passkey authentication which is a passwordless authentication method.
|
||||
The goal of Pocket ID is to be a simple and easy-to-use. There are other self-hosted OIDC providers like [Keycloak](https://www.keycloak.org/) or [ORY Hydra](https://www.ory.sh/hydra/) but they are often too complex for simple use cases.
|
||||
|
||||
Additionally, what makes Pocket ID special is that it only supports [passkey](https://www.passkeys.io/) authentication, which means you don’t need a password. Some people might not like this idea at first, but I believe passkeys are the future, and once you try them, you’ll love them. For example, you can now use a physical Yubikey to sign in to all your self-hosted services easily and securely.
|
||||
|
||||
## Setup
|
||||
|
||||
> [!WARNING]
|
||||
> Pocket ID is in its early stages and may contain bugs.
|
||||
|
||||
### Before you start
|
||||
|
||||
Pocket ID requires a [secure context](https://developer.mozilla.org/en-US/docs/Web/Security/Secure_Contexts), meaning it must be served over HTTPS. This is necessary because Pocket ID uses the [WebAuthn API](https://developer.mozilla.org/en-US/docs/Web/API/Web_Authentication_API) which requires a secure context.
|
||||
|
||||
### Installation with Docker (recommended)
|
||||
|
||||
1. Download the `docker-compose.yml` and `.env` file:
|
||||
@@ -26,47 +32,52 @@ The goal of Pocket ID is to be a simple and easy-to-use. There are other self-ho
|
||||
|
||||
You can now sign in with the admin account on `http://localhost/login/setup`.
|
||||
|
||||
### Unraid
|
||||
|
||||
Pocket ID is available as a template on the Community Apps store.
|
||||
|
||||
### Stand-alone Installation
|
||||
|
||||
Required tools:
|
||||
|
||||
- [Node.js](https://nodejs.org/en/download/) >= 20
|
||||
- [Go](https://golang.org/doc/install) >= 1.22
|
||||
- [Go](https://golang.org/doc/install) >= 1.23
|
||||
- [Git](https://git-scm.com/downloads)
|
||||
- [PM2](https://pm2.keymetrics.io/)
|
||||
- [Caddy](https://caddyserver.com/docs/install) (optional)
|
||||
|
||||
1. Copy the `.env.example` file in the `frontend` and `backend` folder to `.env` and change it so that it fits your needs.
|
||||
|
||||
|
||||
```bash
|
||||
cp frontend/.env.example frontend/.env
|
||||
cp backend/.env.example backend/.env
|
||||
```
|
||||
|
||||
2. Run the following commands:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/stonith404/pocket-id
|
||||
cd pocket-id
|
||||
```bash
|
||||
git clone https://github.com/stonith404/pocket-id
|
||||
cd pocket-id
|
||||
|
||||
# Checkout the latest version
|
||||
git fetch --tags && git checkout $(git describe --tags `git rev-list --tags --max-count=1`)
|
||||
# Checkout the latest version
|
||||
git fetch --tags && git checkout $(git describe --tags `git rev-list --tags --max-count=1`)
|
||||
|
||||
# Start the backend
|
||||
cd backend/cmd
|
||||
go build -o ../pocket-id-backend
|
||||
cd ..
|
||||
pm2 start pocket-id-backend --name pocket-id-backend
|
||||
# Start the backend
|
||||
cd backend/cmd
|
||||
go build -o ../pocket-id-backend
|
||||
cd ..
|
||||
pm2 start pocket-id-backend --name pocket-id-backend
|
||||
|
||||
# Start the frontend
|
||||
cd ../frontend
|
||||
npm install
|
||||
npm run build
|
||||
pm2 start --name pocket-id-frontend --node-args="--env-file .env" build/index.js
|
||||
# Start the frontend
|
||||
cd ../frontend
|
||||
npm install
|
||||
npm run build
|
||||
pm2 start --name pocket-id-frontend --node-args="--env-file .env" build/index.js
|
||||
|
||||
# Optional: Start Caddy (You can use any other reverse proxy)
|
||||
cd ..
|
||||
pm2 start --name pocket-id-caddy caddy -- run --config Caddyfile
|
||||
```
|
||||
# Optional: Start Caddy (You can use any other reverse proxy)
|
||||
cd ..
|
||||
pm2 start --name pocket-id-caddy caddy -- run --config Caddyfile
|
||||
```
|
||||
|
||||
You can now sign in with the admin account on `http://localhost/login/setup`.
|
||||
|
||||
@@ -80,10 +91,17 @@ You may need the following information:
|
||||
|
||||
- **Authorization URL**: `https://<your-domain>/authorize`
|
||||
- **Token URL**: `https://<your-domain>/api/oidc/token`
|
||||
- **Userinfo URL**: `https://<your-domain>/api/oidc/userinfo`
|
||||
- **Certificate URL**: `https://<your-domain>/.well-known/jwks.json`
|
||||
- **OIDC Discovery URL**: `https://<your-domain>/.well-known/openid-configuration`
|
||||
- **PKCE**: `false` as this is not supported yet.
|
||||
|
||||
### Proxy Services with Pocket ID
|
||||
|
||||
As the goal of Pocket ID is to stay simple, we don't have a built-in proxy provider. However, you can use [OAuth2 Proxy](https://oauth2-proxy.github.io/) to add authentication to your services that don't support OIDC.
|
||||
|
||||
See the [guide](docs/proxy-services.md) for more information.
|
||||
|
||||
### Update
|
||||
|
||||
#### Docker
|
||||
@@ -126,13 +144,14 @@ docker compose up -d
|
||||
|
||||
### Environment variables
|
||||
|
||||
| Variable | Default Value | Recommended to change | Description |
|
||||
| ---------------- | ------------------- | --------------------- | --------------------------------------------- |
|
||||
| `PUBLIC_APP_URL` | `http://localhost` | yes | The URL where you will access the app. |
|
||||
| `DB_PATH` | `data/pocket-id.db` | no | The path to the SQLite database. |
|
||||
| `UPLOAD_PATH` | `data/uploads` | no | The path where the uploaded files are stored. |
|
||||
| `PORT` | `3000` | no | The port on which the frontend should listen. |
|
||||
| `BACKEND_PORT` | `8080` | no | The port on which the backend should listen. |
|
||||
| Variable | Default Value | Recommended to change | Description |
|
||||
| ---------------------- | ----------------------- | --------------------- | --------------------------------------------- |
|
||||
| `PUBLIC_APP_URL` | `http://localhost` | yes | The URL where you will access the app. |
|
||||
| `DB_PATH` | `data/pocket-id.db` | no | The path to the SQLite database. |
|
||||
| `UPLOAD_PATH` | `data/uploads` | no | The path where the uploaded files are stored. |
|
||||
| `INTERNAL_BACKEND_URL` | `http://localhost:8080` | no | The URL where the backend is accessible. |
|
||||
| `PORT` | `3000` | no | The port on which the frontend should listen. |
|
||||
| `BACKEND_PORT` | `8080` | no | The port on which the backend should listen. |
|
||||
|
||||
## Contribute
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"golang-rest-api-template/internal/bootstrap"
|
||||
"github.com/stonith404/pocket-id/backend/internal/bootstrap"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
module golang-rest-api-template
|
||||
module github.com/stonith404/pocket-id/backend
|
||||
|
||||
go 1.22
|
||||
go 1.23
|
||||
|
||||
require (
|
||||
github.com/caarlos0/env/v11 v11.2.0
|
||||
github.com/caarlos0/env/v11 v11.2.2
|
||||
github.com/fxamacker/cbor/v2 v2.7.0
|
||||
github.com/gin-contrib/cors v1.7.2
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/go-co-op/gocron/v2 v2.11.0
|
||||
github.com/go-webauthn/webauthn v0.11.0
|
||||
github.com/go-playground/validator/v10 v10.22.0
|
||||
github.com/go-webauthn/webauthn v0.11.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/golang-migrate/migrate/v4 v4.17.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
golang.org/x/crypto v0.25.0
|
||||
golang.org/x/crypto v0.26.0
|
||||
golang.org/x/time v0.6.0
|
||||
gorm.io/driver/sqlite v1.5.6
|
||||
gorm.io/gorm v1.25.11
|
||||
@@ -28,7 +29,6 @@ require (
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.22.0 // indirect
|
||||
github.com/go-webauthn/x v0.1.12 // indirect
|
||||
github.com/goccy/go-json v0.10.3 // indirect
|
||||
github.com/google/go-tpm v0.9.1 // indirect
|
||||
@@ -57,7 +57,7 @@ require (
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect
|
||||
golang.org/x/net v0.27.0 // indirect
|
||||
golang.org/x/sys v0.23.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/text v0.17.0 // indirect
|
||||
google.golang.org/protobuf v1.34.2 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -3,8 +3,8 @@ github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKz
|
||||
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
|
||||
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
|
||||
github.com/caarlos0/env/v11 v11.2.0 h1:kvB1ZmwdWgI3JsuuVUE7z4cY/6Ujr03D0w2WkOOH4Xs=
|
||||
github.com/caarlos0/env/v11 v11.2.0/go.mod h1:LwgkYk1kDvfGpHthrWWLof3Ny7PezzFwS4QrsJdHTMo=
|
||||
github.com/caarlos0/env/v11 v11.2.2 h1:95fApNrUyueipoZN/EhA8mMxiNxrBwDa+oAZrMWl3Kg=
|
||||
github.com/caarlos0/env/v11 v11.2.2/go.mod h1:JBfcdeQiBoI3Zh1QRAWfe+tpiNTmDtcCj/hHHHMx0vc=
|
||||
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
|
||||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
@@ -33,8 +33,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao=
|
||||
github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/go-webauthn/webauthn v0.11.0 h1:2U0jWuGeoiI+XSZkHPFRtwaYtqmMUsqABtlfSq1rODo=
|
||||
github.com/go-webauthn/webauthn v0.11.0/go.mod h1:57ZrqsZzD/eboQDVtBkvTdfqFYAh/7IwzdPT+sPWqB0=
|
||||
github.com/go-webauthn/webauthn v0.11.1 h1:5G/+dg91/VcaJHTtJUfwIlNJkLwbJCcnUc4W8VtkpzA=
|
||||
github.com/go-webauthn/webauthn v0.11.1/go.mod h1:YXRm1WG0OtUyDFaVAgB5KG7kVqW+6dYCJ7FTQH4SxEE=
|
||||
github.com/go-webauthn/x v0.1.12 h1:RjQ5cvApzyU/xLCiP+rub0PE4HBZsLggbxGR5ZpUf/A=
|
||||
github.com/go-webauthn/x v0.1.12/go.mod h1:XlRcGkNH8PT45TfeJYc6gqpOtiOendHhVmnOxh+5yHs=
|
||||
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
|
||||
@@ -122,8 +122,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
golang.org/x/arch v0.9.0 h1:ub9TgUInamJ8mrZIGlBG6/4TqWeMszd4N8lNorbrr6k=
|
||||
golang.org/x/arch v0.9.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
|
||||
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
|
||||
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
|
||||
@@ -132,8 +132,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
|
||||
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
|
||||
golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 15 KiB After Width: | Height: | Size: 15 KiB |
@@ -1,6 +1,6 @@
|
||||
<svg id="a"
|
||||
xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1015 1015">
|
||||
<path id="a" d="M838.28,380.27c-13.38-135.36-124.37-245.08-263.77-257.28h-299.08c-26.25,0-47.57,21.21-47.7,47.46-1.2,233.2-2.39,466.4-3.59,699.61-.14,26.44,21.26,47.95,47.7,47.95h102.86c24.66,0,45.25-18.79,47.5-43.35,11.45-124.75,22.89-249.51,34.34-374.26-43.84-29.69-66.46-88.34-40.37-148.47,10.44-24.06,29.57-43.41,53.58-53.98,86.02-37.84,169.22,24.14,169.22,105.56,0,40.38-20.47,75.98-51.6,96.99,6.71,56.71,13.42,113.43,20.14,170.14,1.2,10.14,11.21,16.74,21.03,13.93,134.16-38.42,223.25-167.79,209.75-304.31Z"/>
|
||||
<path d="M506.6,0c209.52,0,379.98,170.45,379.98,379.96,0,82.33-25.9,160.68-74.91,226.54-48.04,64.59-113.78,111.51-190.13,135.71l-21.1,6.7-50.29-248.04,13.91-6.73c45.41-21.95,74.76-68.71,74.76-119.11,0-72.91-59.31-132.23-132.21-132.23s-132.23,59.32-132.23,132.23c0,50.4,29.36,97.16,74.77,119.11l13.65,6.61-81.01,499.24h-226.36V0h351.18Z" />
|
||||
<style>
|
||||
@media (prefers-color-scheme: dark) {
|
||||
#a path {
|
||||
|
||||
|
Before Width: | Height: | Size: 871 B After Width: | Height: | Size: 696 B |
28
backend/internal/bootstrap/application_images_bootstrap.go
Normal file
28
backend/internal/bootstrap/application_images_bootstrap.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
func initApplicationImages() {
|
||||
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
||||
|
||||
files, err := os.ReadDir(dirPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Fatalf("Error reading directory: %v", err)
|
||||
}
|
||||
|
||||
// Skip if files already exist
|
||||
if len(files) > 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Copy files from source to destination
|
||||
err = utils.CopyDirectory("./images", dirPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Error copying directory: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,78 +1,16 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/common/middleware"
|
||||
"golang-rest-api-template/internal/handler"
|
||||
"golang-rest-api-template/internal/job"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
"github.com/stonith404/pocket-id/backend/internal/job"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
)
|
||||
|
||||
func Bootstrap() {
|
||||
common.InitDatabase()
|
||||
common.InitDbConfig()
|
||||
db := newDatabase()
|
||||
appConfigService := service.NewAppConfigService(db)
|
||||
|
||||
initApplicationImages()
|
||||
job.RegisterJobs()
|
||||
initRouter()
|
||||
}
|
||||
|
||||
func initRouter() {
|
||||
switch common.EnvConfig.AppEnv {
|
||||
case "production":
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
case "development":
|
||||
gin.SetMode(gin.DebugMode)
|
||||
case "test":
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
|
||||
r.Use(gin.Logger())
|
||||
|
||||
r.Use(middleware.Cors())
|
||||
r.Use(middleware.RateLimiter(rate.Every(time.Second), 60))
|
||||
|
||||
apiGroup := r.Group("/api")
|
||||
handler.RegisterRoutes(apiGroup)
|
||||
handler.RegisterOIDCRoutes(apiGroup)
|
||||
handler.RegisterUserRoutes(apiGroup)
|
||||
handler.RegisterConfigurationRoutes(apiGroup)
|
||||
if common.EnvConfig.AppEnv != "production" {
|
||||
handler.RegisterTestRoutes(apiGroup)
|
||||
}
|
||||
|
||||
baseGroup := r.Group("/")
|
||||
handler.RegisterWellKnownRoutes(baseGroup)
|
||||
|
||||
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func initApplicationImages() {
|
||||
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
||||
|
||||
files, err := os.ReadDir(dirPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
log.Fatalf("Error reading directory: %v", err)
|
||||
}
|
||||
|
||||
// Skip if files already exist
|
||||
if len(files) > 1 {
|
||||
return
|
||||
}
|
||||
|
||||
// Copy files from source to destination
|
||||
err = utils.CopyDirectory("./images", dirPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Error copying directory: %v", err)
|
||||
}
|
||||
job.RegisterJobs(db)
|
||||
initRouter(db, appConfigService)
|
||||
}
|
||||
|
||||
@@ -1,51 +1,54 @@
|
||||
package common
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func InitDatabase() {
|
||||
connectDatabase()
|
||||
sqlDb, err := DB.DB()
|
||||
func newDatabase() (db *gorm.DB) {
|
||||
db, err := connectDatabase()
|
||||
if err != nil {
|
||||
log.Fatal("failed to get sql db", err)
|
||||
log.Fatalf("failed to connect to database: %v", err)
|
||||
}
|
||||
sqlDb, err := db.DB()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get sql.DB: %v", err)
|
||||
}
|
||||
|
||||
driver, err := sqlite3.WithInstance(sqlDb, &sqlite3.Config{})
|
||||
m, err := migrate.NewWithDatabaseInstance(
|
||||
"file://migrations",
|
||||
"postgres", driver)
|
||||
if err != nil {
|
||||
log.Fatal("failed to create migration instance", err)
|
||||
log.Fatalf("failed to create migration instance: %v", err)
|
||||
}
|
||||
|
||||
err = m.Up()
|
||||
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
|
||||
log.Fatal("failed to run migrations", err)
|
||||
log.Fatalf("failed to apply migrations: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func connectDatabase() {
|
||||
var database *gorm.DB
|
||||
var err error
|
||||
func connectDatabase() (db *gorm.DB, err error) {
|
||||
dbPath := common.EnvConfig.DBPath
|
||||
|
||||
dbPath := EnvConfig.DBPath
|
||||
if EnvConfig.AppEnv == "test" {
|
||||
// Use in-memory database for testing
|
||||
if common.EnvConfig.AppEnv == "test" {
|
||||
dbPath = "file::memory:?cache=shared"
|
||||
}
|
||||
|
||||
for i := 1; i <= 3; i++ {
|
||||
database, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
db, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||
TranslateError: true,
|
||||
Logger: getLogger(),
|
||||
})
|
||||
@@ -57,11 +60,11 @@ func connectDatabase() {
|
||||
}
|
||||
}
|
||||
|
||||
DB = database
|
||||
return db, err
|
||||
}
|
||||
|
||||
func getLogger() logger.Interface {
|
||||
isProduction := EnvConfig.AppEnv == "production"
|
||||
isProduction := common.EnvConfig.AppEnv == "production"
|
||||
|
||||
var logLevel logger.LogLevel
|
||||
if isProduction {
|
||||
@@ -70,7 +73,6 @@ func getLogger() logger.Interface {
|
||||
logLevel = logger.Info
|
||||
}
|
||||
|
||||
// Create the GORM logger
|
||||
return logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
66
backend/internal/bootstrap/router_bootstrap.go
Normal file
66
backend/internal/bootstrap/router_bootstrap.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/controller"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func initRouter(db *gorm.DB, appConfigService *service.AppConfigService) {
|
||||
// Set the appropriate Gin mode based on the environment
|
||||
switch common.EnvConfig.AppEnv {
|
||||
case "production":
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
case "development":
|
||||
gin.SetMode(gin.DebugMode)
|
||||
case "test":
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
r.Use(gin.Logger())
|
||||
|
||||
// Initialize services
|
||||
webauthnService := service.NewWebAuthnService(db, appConfigService)
|
||||
jwtService := service.NewJwtService(appConfigService)
|
||||
userService := service.NewUserService(db, jwtService)
|
||||
oidcService := service.NewOidcService(db, jwtService)
|
||||
testService := service.NewTestService(db, appConfigService)
|
||||
|
||||
// Add global middleware
|
||||
r.Use(middleware.NewCorsMiddleware().Add())
|
||||
r.Use(middleware.NewRateLimitMiddleware().Add(rate.Every(time.Second), 60))
|
||||
r.Use(middleware.NewJwtAuthMiddleware(jwtService, true).Add(false))
|
||||
|
||||
// Initialize middleware
|
||||
jwtAuthMiddleware := middleware.NewJwtAuthMiddleware(jwtService, false)
|
||||
fileSizeLimitMiddleware := middleware.NewFileSizeLimitMiddleware()
|
||||
|
||||
// Set up API routes
|
||||
apiGroup := r.Group("/api")
|
||||
controller.NewWebauthnController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), webauthnService, jwtService)
|
||||
controller.NewOidcController(apiGroup, jwtAuthMiddleware, fileSizeLimitMiddleware, oidcService, jwtService)
|
||||
controller.NewUserController(apiGroup, jwtAuthMiddleware, middleware.NewRateLimitMiddleware(), userService)
|
||||
controller.NewAppConfigController(apiGroup, jwtAuthMiddleware, appConfigService)
|
||||
|
||||
// Add test controller in non-production environments
|
||||
if common.EnvConfig.AppEnv != "production" {
|
||||
controller.NewTestController(apiGroup, testService)
|
||||
}
|
||||
|
||||
// Set up base routes
|
||||
baseGroup := r.Group("/")
|
||||
controller.NewWellKnownController(baseGroup, jwtService)
|
||||
|
||||
// Run the server
|
||||
if err := r.Run(common.EnvConfig.Host + ":" + common.EnvConfig.Port); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/caarlos0/env/v11"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DBPath string `env:"DB_PATH"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DBPath: "data/pocket-id.db",
|
||||
UploadPath: "data/uploads",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
var DbConfig = NewDefaultDbConfig()
|
||||
|
||||
func NewDefaultDbConfig() model.ApplicationConfiguration {
|
||||
return model.ApplicationConfiguration{
|
||||
AppName: model.ApplicationConfigurationVariable{
|
||||
Key: "appName",
|
||||
Type: "string",
|
||||
IsPublic: true,
|
||||
Value: "Pocket ID",
|
||||
},
|
||||
SessionDuration: model.ApplicationConfigurationVariable{
|
||||
Key: "sessionDuration",
|
||||
Type: "number",
|
||||
Value: "60",
|
||||
},
|
||||
BackgroundImageType: model.ApplicationConfigurationVariable{
|
||||
Key: "backgroundImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "jpg",
|
||||
},
|
||||
LogoImageType: model.ApplicationConfigurationVariable{
|
||||
Key: "logoImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "svg",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LoadDbConfigFromDb refreshes the database configuration by loading the current values
|
||||
// from the database and updating the DbConfig struct.
|
||||
func LoadDbConfigFromDb() error {
|
||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
||||
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
||||
var storedConfigVar model.ApplicationConfigurationVariable
|
||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
||||
// updates existing configurations if they differ from the default, and deletes any configurations
|
||||
// that are not in the default configuration.
|
||||
func InitDbConfig() {
|
||||
// Reflect to get the underlying value of DbConfig and its default configuration
|
||||
dbConfigReflectValue := reflect.ValueOf(&DbConfig).Elem()
|
||||
defaultDbConfig := NewDefaultDbConfig()
|
||||
defaultConfigReflectValue := reflect.ValueOf(&defaultDbConfig).Elem()
|
||||
defaultKeys := make(map[string]struct{})
|
||||
|
||||
// Iterate over the fields of DbConfig
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.ApplicationConfigurationVariable)
|
||||
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.ApplicationConfigurationVariable)
|
||||
defaultKeys[currentConfigVar.Key] = struct{}{}
|
||||
|
||||
var storedConfigVar model.ApplicationConfigurationVariable
|
||||
if err := DB.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
// If the configuration does not exist, create it
|
||||
if err := DB.Create(&defaultConfigVar).Error; err != nil {
|
||||
log.Fatalf("Failed to create default configuration: %v", err)
|
||||
}
|
||||
dbConfigField.Set(reflect.ValueOf(defaultConfigVar))
|
||||
continue
|
||||
}
|
||||
|
||||
// Update existing configuration if it differs from the default
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
||||
storedConfigVar.Type = defaultConfigVar.Type
|
||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||
if err := DB.Save(&storedConfigVar).Error; err != nil {
|
||||
log.Fatalf("Failed to update configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Set the value in DbConfig
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
// Delete any configurations not in the default keys
|
||||
var allConfigVars []model.ApplicationConfigurationVariable
|
||||
if err := DB.Find(&allConfigVars).Error; err != nil {
|
||||
log.Fatalf("Failed to retrieve existing configurations: %v", err)
|
||||
}
|
||||
|
||||
for _, config := range allConfigVars {
|
||||
if _, exists := defaultKeys[config.Key]; !exists {
|
||||
if err := DB.Delete(&config).Error; err != nil {
|
||||
log.Fatalf("Failed to delete outdated configuration: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
31
backend/internal/common/env_config.go
Normal file
31
backend/internal/common/env_config.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/caarlos0/env/v11"
|
||||
_ "github.com/joho/godotenv/autoload"
|
||||
"log"
|
||||
)
|
||||
|
||||
type EnvConfigSchema struct {
|
||||
AppEnv string `env:"APP_ENV"`
|
||||
AppURL string `env:"PUBLIC_APP_URL"`
|
||||
DBPath string `env:"DB_PATH"`
|
||||
UploadPath string `env:"UPLOAD_PATH"`
|
||||
Port string `env:"BACKEND_PORT"`
|
||||
Host string `env:"HOST"`
|
||||
}
|
||||
|
||||
var EnvConfig = &EnvConfigSchema{
|
||||
AppEnv: "production",
|
||||
DBPath: "data/pocket-id.db",
|
||||
UploadPath: "data/uploads",
|
||||
AppURL: "http://localhost",
|
||||
Port: "8080",
|
||||
Host: "localhost",
|
||||
}
|
||||
|
||||
func init() {
|
||||
if err := env.ParseWithOptions(EnvConfig, env.Options{}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
18
backend/internal/common/errors.go
Normal file
18
backend/internal/common/errors.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package common
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrUsernameTaken = errors.New("username is already taken")
|
||||
ErrEmailTaken = errors.New("email is already taken")
|
||||
ErrSetupAlreadyCompleted = errors.New("setup already completed")
|
||||
ErrTokenInvalidOrExpired = errors.New("token is invalid or expired")
|
||||
ErrOidcMissingAuthorization = errors.New("missing authorization")
|
||||
ErrOidcGrantTypeNotSupported = errors.New("grant type not supported")
|
||||
ErrOidcMissingClientCredentials = errors.New("client id or secret not provided")
|
||||
ErrOidcClientSecretInvalid = errors.New("invalid client secret")
|
||||
ErrOidcInvalidAuthorizationCode = errors.New("invalid authorization code")
|
||||
ErrOidcInvalidCallbackURL = errors.New("invalid callback URL")
|
||||
ErrFileTypeNotSupported = errors.New("file type not supported")
|
||||
ErrInvalidCredentials = errors.New("no user found with provided credentials")
|
||||
)
|
||||
@@ -1,209 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"log"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
PrivateKey *rsa.PrivateKey
|
||||
PublicKey *rsa.PublicKey
|
||||
)
|
||||
|
||||
const (
|
||||
privateKeyPath = "data/keys/jwt_private_key.pem"
|
||||
publicKeyPath = "data/keys/jwt_public_key.pem"
|
||||
)
|
||||
|
||||
type accessTokenJWTClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
}
|
||||
|
||||
// GenerateIDToken generates an ID token for the given user, clientID, scope and nonce.
|
||||
func GenerateIDToken(user model.User, clientID string, scope string, nonce string) (tokenString string, err error) {
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"email": user.Email,
|
||||
"preferred_username": user.Username,
|
||||
}
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID,
|
||||
"aud": clientID,
|
||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
"iat": jwt.NewNumericDate(time.Now()),
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
if strings.Contains(scope, "profile") {
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
signedToken, err := token.SignedString(PrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
|
||||
// GenerateAccessToken generates an access token for the given user.
|
||||
func GenerateAccessToken(user model.User) (tokenString string, err error) {
|
||||
sessionDurationInMinutes, _ := strconv.Atoi(DbConfig.SessionDuration.Value)
|
||||
claim := accessTokenJWTClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{utils.GetHostFromURL(EnvConfig.AppURL)},
|
||||
},
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
tokenString, err = token.SignedString(PrivateKey)
|
||||
return tokenString, err
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies the given access token and returns the claims if the token is valid.
|
||||
func VerifyAccessToken(tokenString string) (*accessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &accessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return PublicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*accessTokenJWTClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, utils.GetHostFromURL(EnvConfig.AppURL)) {
|
||||
return nil, errors.New("audience doesn't match")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Kid string `json:"kid"`
|
||||
Alg string `json:"alg"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func GetJWK() (JWK, error) {
|
||||
if PublicKey == nil {
|
||||
return JWK{}, errors.New("public key is not initialized")
|
||||
}
|
||||
|
||||
// Create JWK from RSA public key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Kid: "1", // Key ID can be set to any identifier. Here it's statically set to "1"
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(PublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(PublicKey.E)).Bytes()),
|
||||
}
|
||||
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
// generateKeys generates a new RSA key pair and saves the private and public keys to the data folder.
|
||||
func generateKeys() {
|
||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||
log.Fatal("Failed to create directories for keys", err)
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to generate private key", err)
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.Create(privateKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to create private key file", err)
|
||||
}
|
||||
defer privateKeyFile.Close()
|
||||
|
||||
privateKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
},
|
||||
)
|
||||
_, err = privateKeyFile.Write(privateKeyPEM)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to write private key file", err)
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
publicKeyFile, err := os.Create(publicKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to create public key file", err)
|
||||
}
|
||||
defer publicKeyFile.Close()
|
||||
|
||||
publicKeyPEM := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: x509.MarshalPKCS1PublicKey(publicKey),
|
||||
},
|
||||
)
|
||||
_, err = publicKeyFile.Write(publicKeyPEM)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to write public key file", err)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
generateKeys()
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Can't read jwt private key", err)
|
||||
}
|
||||
PrivateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Can't parse jwt private key", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
if err != nil {
|
||||
log.Fatal("Can't read jwt public key", err)
|
||||
}
|
||||
PublicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Can't parse jwt public key", err)
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func JWTAuth(adminOnly bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// Extract the token from the cookie or the Authorization header
|
||||
token, err := c.Cookie("access_token")
|
||||
if err != nil {
|
||||
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
|
||||
if len(authorizationHeaderSplitted) == 2 {
|
||||
token = authorizationHeaderSplitted[1]
|
||||
} else {
|
||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
claims, err := common.VerifyAccessToken(token)
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
if adminOnly && !claims.IsAdmin {
|
||||
utils.HandlerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("userID", claims.Subject)
|
||||
c.Set("userIsAdmin", claims.IsAdmin)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
WebAuthn *webauthn.WebAuthn
|
||||
err error
|
||||
)
|
||||
|
||||
func init() {
|
||||
config := &webauthn.Config{
|
||||
RPDisplayName: DbConfig.AppName.Value,
|
||||
RPID: utils.GetHostFromURL(EnvConfig.AppURL),
|
||||
RPOrigins: []string{EnvConfig.AppURL},
|
||||
Timeouts: webauthn.TimeoutsConfig{
|
||||
Login: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
Registration: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if WebAuthn, err = webauthn.New(config); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
158
backend/internal/controller/app_config_controller.go
Normal file
158
backend/internal/controller/app_config_controller.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func NewAppConfigController(
|
||||
group *gin.RouterGroup,
|
||||
jwtAuthMiddleware *middleware.JwtAuthMiddleware,
|
||||
appConfigService *service.AppConfigService) {
|
||||
|
||||
acc := &AppConfigController{
|
||||
appConfigService: appConfigService,
|
||||
}
|
||||
group.GET("/application-configuration", acc.listApplicationConfigurationHandler)
|
||||
group.GET("/application-configuration/all", jwtAuthMiddleware.Add(true), acc.listAllApplicationConfigurationHandler)
|
||||
group.PUT("/application-configuration", acc.updateApplicationConfigurationHandler)
|
||||
|
||||
group.GET("/application-configuration/logo", acc.getLogoHandler)
|
||||
group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler)
|
||||
group.GET("/application-configuration/favicon", acc.getFaviconHandler)
|
||||
group.PUT("/application-configuration/logo", jwtAuthMiddleware.Add(true), acc.updateLogoHandler)
|
||||
group.PUT("/application-configuration/favicon", jwtAuthMiddleware.Add(true), acc.updateFaviconHandler)
|
||||
group.PUT("/application-configuration/background-image", jwtAuthMiddleware.Add(true), acc.updateBackgroundImageHandler)
|
||||
}
|
||||
|
||||
type AppConfigController struct {
|
||||
appConfigService *service.AppConfigService
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) listApplicationConfigurationHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListApplicationConfiguration(false)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.PublicAppConfigVariableDto
|
||||
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, configVariablesDto)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) listAllApplicationConfigurationHandler(c *gin.Context) {
|
||||
configuration, err := acc.appConfigService.ListApplicationConfiguration(true)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.AppConfigVariableDto
|
||||
if err := dto.MapStructList(configuration, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, configVariablesDto)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) updateApplicationConfigurationHandler(c *gin.Context) {
|
||||
var input dto.AppConfigUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
savedConfigVariables, err := acc.appConfigService.UpdateApplicationConfiguration(input)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var configVariablesDto []dto.AppConfigVariableDto
|
||||
if err := dto.MapStructList(savedConfigVariables, &configVariablesDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, configVariablesDto)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||
acc.getImage(c, "logo", imageType)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
|
||||
acc.getImage(c, "favicon", "ico")
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||
acc.getImage(c, "background", imageType)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.DbConfig.LogoImageType.Value
|
||||
acc.updateImage(c, "logo", imageType)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if fileType != "ico" {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "File must be of type .ico")
|
||||
return
|
||||
}
|
||||
acc.updateImage(c, "favicon", "ico")
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
|
||||
imageType := acc.appConfigService.DbConfig.BackgroundImageType.Value
|
||||
acc.updateImage(c, "background", imageType)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
}
|
||||
|
||||
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = acc.appConfigService.UpdateImage(file, imageName, oldImageType)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
303
backend/internal/controller/oidc_controller.go
Normal file
303
backend/internal/controller/oidc_controller.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func NewOidcController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, fileSizeLimitMiddleware *middleware.FileSizeLimitMiddleware, oidcService *service.OidcService, jwtService *service.JwtService) {
|
||||
oc := &OidcController{oidcService: oidcService, jwtService: jwtService}
|
||||
|
||||
group.POST("/oidc/authorize", jwtAuthMiddleware.Add(false), oc.authorizeHandler)
|
||||
group.POST("/oidc/authorize/new-client", jwtAuthMiddleware.Add(false), oc.authorizeNewClientHandler)
|
||||
group.POST("/oidc/token", oc.createIDTokenHandler)
|
||||
group.GET("/oidc/userinfo", oc.userInfoHandler)
|
||||
|
||||
group.GET("/oidc/clients", jwtAuthMiddleware.Add(true), oc.listClientsHandler)
|
||||
group.POST("/oidc/clients", jwtAuthMiddleware.Add(true), oc.createClientHandler)
|
||||
group.GET("/oidc/clients/:id", oc.getClientHandler)
|
||||
group.PUT("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.updateClientHandler)
|
||||
group.DELETE("/oidc/clients/:id", jwtAuthMiddleware.Add(true), oc.deleteClientHandler)
|
||||
|
||||
group.POST("/oidc/clients/:id/secret", jwtAuthMiddleware.Add(true), oc.createClientSecretHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/logo", oc.getClientLogoHandler)
|
||||
group.DELETE("/oidc/clients/:id/logo", oc.deleteClientLogoHandler)
|
||||
group.POST("/oidc/clients/:id/logo", jwtAuthMiddleware.Add(true), fileSizeLimitMiddleware.Add(2<<20), oc.updateClientLogoHandler)
|
||||
}
|
||||
|
||||
type OidcController struct {
|
||||
oidcService *service.OidcService
|
||||
jwtService *service.JwtService
|
||||
}
|
||||
|
||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.Authorize(input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcMissingAuthorization) {
|
||||
utils.CustomControllerError(c, http.StatusForbidden, err.Error())
|
||||
} else if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.AuthorizeOidcClientResponseDto{
|
||||
Code: code,
|
||||
CallbackURL: callbackURL,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
|
||||
var input dto.AuthorizeOidcClientRequestDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
code, callbackURL, err := oc.oidcService.AuthorizeNewClient(input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcInvalidCallbackURL) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.AuthorizeOidcClientResponseDto{
|
||||
Code: code,
|
||||
CallbackURL: callbackURL,
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (oc *OidcController) createIDTokenHandler(c *gin.Context) {
|
||||
var input dto.OidcIdTokenDto
|
||||
|
||||
if err := c.ShouldBind(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := input.ClientID
|
||||
clientSecret := input.ClientSecret
|
||||
|
||||
// Client id and secret can also be passed over the Authorization header
|
||||
if clientID == "" || clientSecret == "" {
|
||||
var ok bool
|
||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrOidcGrantTypeNotSupported) ||
|
||||
errors.Is(err, common.ErrOidcMissingClientCredentials) ||
|
||||
errors.Is(err, common.ErrOidcClientSecretInvalid) ||
|
||||
errors.Is(err, common.ErrOidcInvalidAuthorizationCode) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"id_token": idToken, "access_token": accessToken, "token_type": "Bearer"})
|
||||
}
|
||||
|
||||
func (oc *OidcController) userInfoHandler(c *gin.Context) {
|
||||
token := strings.Split(c.GetHeader("Authorization"), " ")[1]
|
||||
jwtClaims, err := oc.jwtService.VerifyOauthAccessToken(token)
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, common.ErrTokenInvalidOrExpired.Error())
|
||||
return
|
||||
}
|
||||
userID := jwtClaims.Subject
|
||||
clientId := jwtClaims.Audience[0]
|
||||
claims, err := oc.oidcService.GetUserClaimsForClient(userID, clientId)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, claims)
|
||||
}
|
||||
|
||||
func (oc *OidcController) getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
client, err := oc.oidcService.GetClient(clientId)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Return a different DTO based on the user's role
|
||||
if c.GetBool("userIsAdmin") {
|
||||
clientDto := dto.OidcClientDto{}
|
||||
err = dto.MapStruct(client, &clientDto)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
clientDto := dto.PublicOidcClientDto{}
|
||||
err = dto.MapStruct(client, &clientDto)
|
||||
if err == nil {
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
|
||||
func (oc *OidcController) listClientsHandler(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||
searchTerm := c.Query("search")
|
||||
|
||||
clients, pagination, err := oc.oidcService.ListClients(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientsDto []dto.OidcClientDto
|
||||
if err := dto.MapStructList(clients, &clientsDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": clientsDto,
|
||||
"pagination": pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func (oc *OidcController) createClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.CreateClient(input, c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, clientDto)
|
||||
}
|
||||
|
||||
func (oc *OidcController) deleteClientHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClient(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (oc *OidcController) updateClientHandler(c *gin.Context) {
|
||||
var input dto.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client, err := oc.oidcService.UpdateClient(c.Param("id"), input)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var clientDto dto.OidcClientDto
|
||||
if err := dto.MapStruct(client, &clientDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, clientDto)
|
||||
}
|
||||
|
||||
func (oc *OidcController) createClientSecretHandler(c *gin.Context) {
|
||||
secret, err := oc.oidcService.CreateClientSecret(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"secret": secret})
|
||||
}
|
||||
|
||||
func (oc *OidcController) getClientLogoHandler(c *gin.Context) {
|
||||
imagePath, mimeType, err := oc.oidcService.GetClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
}
|
||||
|
||||
func (oc *OidcController) updateClientLogoHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
err = oc.oidcService.UpdateClientLogo(c.Param("id"), file)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrFileTypeNotSupported) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (oc *OidcController) deleteClientLogoHandler(c *gin.Context) {
|
||||
err := oc.oidcService.DeleteClientLogo(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
37
backend/internal/controller/test_controller.go
Normal file
37
backend/internal/controller/test_controller.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func NewTestController(group *gin.RouterGroup, testService *service.TestService) {
|
||||
testController := &TestController{TestService: testService}
|
||||
|
||||
group.POST("/test/reset", testController.resetAndSeedHandler)
|
||||
}
|
||||
|
||||
type TestController struct {
|
||||
TestService *service.TestService
|
||||
}
|
||||
|
||||
func (tc *TestController) resetAndSeedHandler(c *gin.Context) {
|
||||
if err := tc.TestService.ResetDatabase(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.ResetApplicationImages(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tc.TestService.SeedDatabase(); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
220
backend/internal/controller/user_controller.go
Normal file
220
backend/internal/controller/user_controller.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewUserController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, userService *service.UserService) {
|
||||
uc := UserController{
|
||||
UserService: userService,
|
||||
}
|
||||
|
||||
group.GET("/users", jwtAuthMiddleware.Add(true), uc.listUsersHandler)
|
||||
group.GET("/users/me", jwtAuthMiddleware.Add(false), uc.getCurrentUserHandler)
|
||||
group.GET("/users/:id", jwtAuthMiddleware.Add(true), uc.getUserHandler)
|
||||
group.POST("/users", jwtAuthMiddleware.Add(true), uc.createUserHandler)
|
||||
group.PUT("/users/:id", jwtAuthMiddleware.Add(true), uc.updateUserHandler)
|
||||
group.PUT("/users/me", jwtAuthMiddleware.Add(false), uc.updateCurrentUserHandler)
|
||||
group.DELETE("/users/:id", jwtAuthMiddleware.Add(true), uc.deleteUserHandler)
|
||||
|
||||
group.POST("/users/:id/one-time-access-token", jwtAuthMiddleware.Add(true), uc.createOneTimeAccessTokenHandler)
|
||||
group.POST("/one-time-access-token/:token", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), uc.exchangeOneTimeAccessTokenHandler)
|
||||
group.POST("/one-time-access-token/setup", uc.getSetupAccessTokenHandler)
|
||||
}
|
||||
|
||||
type UserController struct {
|
||||
UserService *service.UserService
|
||||
}
|
||||
|
||||
func (uc *UserController) listUsersHandler(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||
searchTerm := c.Query("search")
|
||||
|
||||
users, pagination, err := uc.UserService.ListUsers(searchTerm, page, pageSize)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var usersDto []dto.UserDto
|
||||
if err := dto.MapStructList(users, &usersDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": usersDto,
|
||||
"pagination": pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func (uc *UserController) getUserHandler(c *gin.Context) {
|
||||
user, err := uc.UserService.GetUser(c.Param("id"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
func (uc *UserController) getCurrentUserHandler(c *gin.Context) {
|
||||
user, err := uc.UserService.GetUser(c.GetString("userID"))
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
func (uc *UserController) deleteUserHandler(c *gin.Context) {
|
||||
if err := uc.UserService.DeleteUser(c.Param("id")); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (uc *UserController) createUserHandler(c *gin.Context) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := uc.UserService.CreateUser(input)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, userDto)
|
||||
}
|
||||
|
||||
func (uc *UserController) updateUserHandler(c *gin.Context) {
|
||||
uc.updateUser(c, false)
|
||||
}
|
||||
|
||||
func (uc *UserController) updateCurrentUserHandler(c *gin.Context) {
|
||||
uc.updateUser(c, true)
|
||||
}
|
||||
|
||||
func (uc *UserController) createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
var input dto.OneTimeAccessTokenCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := uc.UserService.CreateOneTimeAccessToken(input.UserID, input.ExpiresAt)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{"token": token})
|
||||
}
|
||||
|
||||
func (uc *UserController) exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.UserService.ExchangeOneTimeAccessToken(c.Param("token"))
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrTokenInvalidOrExpired) {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func (uc *UserController) getSetupAccessTokenHandler(c *gin.Context) {
|
||||
user, token, err := uc.UserService.SetupInitialAdmin()
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrSetupAlreadyCompleted) {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
func (uc *UserController) updateUser(c *gin.Context, updateOwnUser bool) {
|
||||
var input dto.UserCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var userID string
|
||||
if updateOwnUser {
|
||||
userID = c.GetString("userID")
|
||||
} else {
|
||||
userID = c.Param("id")
|
||||
}
|
||||
|
||||
user, err := uc.UserService.UpdateUser(userID, input, updateOwnUser)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrEmailTaken) || errors.Is(err, common.ErrUsernameTaken) {
|
||||
utils.CustomControllerError(c, http.StatusConflict, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
182
backend/internal/controller/webauthn_controller.go
Normal file
182
backend/internal/controller/webauthn_controller.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/middleware"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
func NewWebauthnController(group *gin.RouterGroup, jwtAuthMiddleware *middleware.JwtAuthMiddleware, rateLimitMiddleware *middleware.RateLimitMiddleware, webauthnService *service.WebAuthnService, jwtService *service.JwtService) {
|
||||
wc := &WebauthnController{webAuthnService: webauthnService, jwtService: jwtService}
|
||||
group.GET("/webauthn/register/start", jwtAuthMiddleware.Add(false), wc.beginRegistrationHandler)
|
||||
group.POST("/webauthn/register/finish", jwtAuthMiddleware.Add(false), wc.verifyRegistrationHandler)
|
||||
|
||||
group.GET("/webauthn/login/start", wc.beginLoginHandler)
|
||||
group.POST("/webauthn/login/finish", rateLimitMiddleware.Add(rate.Every(10*time.Second), 5), wc.verifyLoginHandler)
|
||||
|
||||
group.POST("/webauthn/logout", jwtAuthMiddleware.Add(false), wc.logoutHandler)
|
||||
|
||||
group.GET("/webauthn/credentials", jwtAuthMiddleware.Add(false), wc.listCredentialsHandler)
|
||||
group.PATCH("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.updateCredentialHandler)
|
||||
group.DELETE("/webauthn/credentials/:id", jwtAuthMiddleware.Add(false), wc.deleteCredentialHandler)
|
||||
}
|
||||
|
||||
type WebauthnController struct {
|
||||
webAuthnService *service.WebAuthnService
|
||||
jwtService *service.JwtService
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) beginRegistrationHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
options, err := wc.webAuthnService.BeginRegistration(userID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, options.Response)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) verifyRegistrationHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("userID")
|
||||
credential, err := wc.webAuthnService.VerifyRegistration(sessionID, userID, c.Request)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDto dto.WebauthnCredentialDto
|
||||
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, credentialDto)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) beginLoginHandler(c *gin.Context) {
|
||||
options, err := wc.webAuthnService.BeginLogin()
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("session_id", options.SessionID, int(options.Timeout.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, options.Response)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) verifyLoginHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
return
|
||||
}
|
||||
|
||||
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("userID")
|
||||
user, err := wc.webAuthnService.VerifyLogin(sessionID, userID, credentialAssertionData)
|
||||
if err != nil {
|
||||
if errors.Is(err, common.ErrInvalidCredentials) {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, err.Error())
|
||||
} else {
|
||||
utils.ControllerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
token, err := wc.jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var userDto dto.UserDto
|
||||
if err := dto.MapStruct(user, &userDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, userDto)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) listCredentialsHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentials, err := wc.webAuthnService.ListCredentials(userID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDtos []dto.WebauthnCredentialDto
|
||||
if err := dto.MapStructList(credentials, &credentialDtos); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, credentialDtos)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) deleteCredentialHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentialID := c.Param("id")
|
||||
|
||||
err := wc.webAuthnService.DeleteCredential(userID, credentialID)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) updateCredentialHandler(c *gin.Context) {
|
||||
userID := c.GetString("userID")
|
||||
credentialID := c.Param("id")
|
||||
|
||||
var input dto.WebauthnCredentialUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := wc.webAuthnService.UpdateCredential(userID, credentialID, input.Name)
|
||||
if err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var credentialDto dto.WebauthnCredentialDto
|
||||
if err := dto.MapStruct(credential, &credentialDto); err != nil {
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, credentialDto)
|
||||
}
|
||||
|
||||
func (wc *WebauthnController) logoutHandler(c *gin.Context) {
|
||||
c.SetCookie("access_token", "", 0, "/", "", false, true)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
@@ -1,33 +1,40 @@
|
||||
package handler
|
||||
package controller
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func RegisterWellKnownRoutes(group *gin.RouterGroup) {
|
||||
group.GET("/.well-known/jwks.json", jwks)
|
||||
group.GET("/.well-known/openid-configuration", openIDConfiguration)
|
||||
func NewWellKnownController(group *gin.RouterGroup, jwtService *service.JwtService) {
|
||||
wkc := &WellKnownController{jwtService: jwtService}
|
||||
group.GET("/.well-known/jwks.json", wkc.jwksHandler)
|
||||
group.GET("/.well-known/openid-configuration", wkc.openIDConfigurationHandler)
|
||||
}
|
||||
|
||||
func jwks(c *gin.Context) {
|
||||
jwk, err := common.GetJWK()
|
||||
type WellKnownController struct {
|
||||
jwtService *service.JwtService
|
||||
}
|
||||
|
||||
func (wkc *WellKnownController) jwksHandler(c *gin.Context) {
|
||||
jwk, err := wkc.jwtService.GetJWK()
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
utils.ControllerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"keys": []interface{}{jwk}})
|
||||
}
|
||||
|
||||
func openIDConfiguration(c *gin.Context) {
|
||||
func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
||||
appUrl := common.EnvConfig.AppURL
|
||||
config := map[string]interface{}{
|
||||
"issuer": appUrl,
|
||||
"authorization_endpoint": appUrl + "/authorize",
|
||||
"token_endpoint": appUrl + "/api/oidc/token",
|
||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
||||
"scopes_supported": []string{"openid", "profile", "email"},
|
||||
"claims_supported": []string{"sub", "given_name", "family_name", "email", "preferred_username"},
|
||||
17
backend/internal/dto/app_config_dto.go
Normal file
17
backend/internal/dto/app_config_dto.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package dto
|
||||
|
||||
type PublicAppConfigVariableDto struct {
|
||||
Key string `json:"key"`
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type AppConfigVariableDto struct {
|
||||
PublicAppConfigVariableDto
|
||||
IsPublic bool `json:"isPublic"`
|
||||
}
|
||||
|
||||
type AppConfigUpdateDto struct {
|
||||
AppName string `json:"appName" binding:"required,min=1,max=30"`
|
||||
SessionDuration string `json:"sessionDuration" binding:"required"`
|
||||
}
|
||||
81
backend/internal/dto/dto_mapper.go
Normal file
81
backend/internal/dto/dto_mapper.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// MapStructList maps a list of source structs to a list of destination structs
|
||||
func MapStructList[S any, D any](source []S, destination *[]D) error {
|
||||
*destination = make([]D, 0, len(source))
|
||||
|
||||
for _, item := range source {
|
||||
var destItem D
|
||||
if err := MapStruct(item, &destItem); err != nil {
|
||||
return err
|
||||
}
|
||||
*destination = append(*destination, destItem)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MapStruct maps a source struct to a destination struct
|
||||
func MapStruct[S any, D any](source S, destination *D) error {
|
||||
// Ensure destination is a non-nil pointer
|
||||
destValue := reflect.ValueOf(destination)
|
||||
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
|
||||
return errors.New("destination must be a non-nil pointer to a struct")
|
||||
}
|
||||
|
||||
// Ensure source is a struct
|
||||
sourceValue := reflect.ValueOf(source)
|
||||
if sourceValue.Kind() != reflect.Struct {
|
||||
return errors.New("source must be a struct")
|
||||
}
|
||||
|
||||
return mapStructInternal(sourceValue, destValue.Elem())
|
||||
}
|
||||
|
||||
func mapStructInternal(sourceVal reflect.Value, destVal reflect.Value) error {
|
||||
// Loop through the fields of the destination struct
|
||||
for i := 0; i < destVal.NumField(); i++ {
|
||||
destField := destVal.Field(i)
|
||||
destFieldType := destVal.Type().Field(i)
|
||||
|
||||
if destFieldType.Anonymous {
|
||||
// Recursively handle embedded structs
|
||||
if err := mapStructInternal(sourceVal, destField); err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
sourceField := sourceVal.FieldByName(destFieldType.Name)
|
||||
|
||||
// If the source field is valid and can be assigned to the destination field
|
||||
if sourceField.IsValid() && destField.CanSet() {
|
||||
// Handle direct assignment for simple types
|
||||
if sourceField.Type() == destField.Type() {
|
||||
destField.Set(sourceField)
|
||||
} else if sourceField.Kind() == reflect.Slice && destField.Kind() == reflect.Slice {
|
||||
// Handle slices
|
||||
if sourceField.Type().Elem() == destField.Type().Elem() {
|
||||
newSlice := reflect.MakeSlice(destField.Type(), sourceField.Len(), sourceField.Cap())
|
||||
|
||||
for j := 0; j < sourceField.Len(); j++ {
|
||||
newSlice.Index(j).Set(sourceField.Index(j))
|
||||
}
|
||||
|
||||
destField.Set(newSlice)
|
||||
}
|
||||
} else if sourceField.Kind() == reflect.Struct && destField.Kind() == reflect.Struct {
|
||||
// Recursively map nested structs
|
||||
if err := mapStructInternal(sourceField, destField); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
37
backend/internal/dto/oidc_dto.go
Normal file
37
backend/internal/dto/oidc_dto.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package dto
|
||||
|
||||
type PublicOidcClientDto struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type OidcClientDto struct {
|
||||
PublicOidcClientDto
|
||||
HasLogo bool `json:"hasLogo"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
CreatedBy UserDto `json:"createdBy"`
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
Name string `json:"name" binding:"required,max=50"`
|
||||
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
|
||||
}
|
||||
|
||||
type AuthorizeOidcClientRequestDto struct {
|
||||
ClientID string `json:"clientID" binding:"required"`
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
type AuthorizeOidcClientResponseDto struct {
|
||||
Code string `json:"code"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
}
|
||||
|
||||
type OidcIdTokenDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code" binding:"required"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
}
|
||||
30
backend/internal/dto/user_dto.go
Normal file
30
backend/internal/dto/user_dto.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package dto
|
||||
|
||||
import "time"
|
||||
|
||||
type UserDto struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
}
|
||||
|
||||
type UserCreateDto struct {
|
||||
Username string `json:"username" binding:"required,username,min=3,max=20"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
FirstName string `json:"firstName" binding:"required,min=3,max=30"`
|
||||
LastName string `json:"lastName" binding:"required,min=3,max=30"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
}
|
||||
|
||||
type OneTimeAccessTokenCreateDto struct {
|
||||
UserID string `json:"userId" binding:"required"`
|
||||
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginUserDto struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
39
backend/internal/dto/validations.go
Normal file
39
backend/internal/dto/validations.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"log"
|
||||
"net/url"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validateUrlList validator.Func = func(fl validator.FieldLevel) bool {
|
||||
urls := fl.Field().Interface().([]string)
|
||||
for _, u := range urls {
|
||||
_, err := url.ParseRequestURI(u)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var validateUsername validator.Func = func(fl validator.FieldLevel) bool {
|
||||
regex := "^[a-z0-9_]*$"
|
||||
matched, _ := regexp.MatchString(regex, fl.Field().String())
|
||||
return matched
|
||||
}
|
||||
|
||||
func init() {
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
if err := v.RegisterValidation("urlList", validateUrlList); err != nil {
|
||||
log.Fatalf("Failed to register custom validation: %v", err)
|
||||
}
|
||||
}
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
if err := v.RegisterValidation("username", validateUsername); err != nil {
|
||||
log.Fatalf("Failed to register custom validation: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
23
backend/internal/dto/webauthn_dto.go
Normal file
23
backend/internal/dto/webauthn_dto.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WebauthnCredentialDto struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CredentialID string `json:"credentialID"`
|
||||
AttestationType string `json:"attestationType"`
|
||||
Transport []protocol.AuthenticatorTransport `json:"transport"`
|
||||
|
||||
BackupEligible bool `json:"backupEligible"`
|
||||
BackupState bool `json:"backupState"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
type WebauthnCredentialUpdateDto struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=30"`
|
||||
}
|
||||
@@ -1,196 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/common/middleware"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func RegisterConfigurationRoutes(group *gin.RouterGroup) {
|
||||
group.GET("/application-configuration", listApplicationConfigurationHandler)
|
||||
group.GET("/application-configuration/all", middleware.JWTAuth(true), listAllApplicationConfigurationHandler)
|
||||
group.PUT("/application-configuration", updateApplicationConfigurationHandler)
|
||||
|
||||
group.GET("/application-configuration/logo", getLogoHandler)
|
||||
group.GET("/application-configuration/background-image", getBackgroundImageHandler)
|
||||
group.GET("/application-configuration/favicon", getFaviconHandler)
|
||||
group.PUT("/application-configuration/logo", middleware.JWTAuth(true), updateLogoHandler)
|
||||
group.PUT("/application-configuration/favicon", middleware.JWTAuth(true), updateFaviconHandler)
|
||||
group.PUT("/application-configuration/background-image", middleware.JWTAuth(true), updateBackgroundImageHandler)
|
||||
}
|
||||
|
||||
func listApplicationConfigurationHandler(c *gin.Context) {
|
||||
listApplicationConfiguration(c, false)
|
||||
}
|
||||
|
||||
func listAllApplicationConfigurationHandler(c *gin.Context) {
|
||||
listApplicationConfiguration(c, true)
|
||||
}
|
||||
|
||||
func updateApplicationConfigurationHandler(c *gin.Context) {
|
||||
var input model.ApplicationConfigurationUpdateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
savedConfigVariables := make([]model.ApplicationConfigurationVariable, 10)
|
||||
|
||||
tx := common.DB.Begin()
|
||||
rt := reflect.ValueOf(input).Type()
|
||||
rv := reflect.ValueOf(input)
|
||||
|
||||
// Loop over the input struct fields and update the related configuration variables
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
field := rt.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
value := rv.FieldByName(field.Name).String()
|
||||
|
||||
// Get the existing configuration variable from the db
|
||||
var applicationConfigurationVariable model.ApplicationConfigurationVariable
|
||||
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
|
||||
tx.Rollback()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, fmt.Sprintf("Invalid configuration variable '%s'", value))
|
||||
} else {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Update the value of the existing configuration variable and save it
|
||||
applicationConfigurationVariable.Value = value
|
||||
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
|
||||
tx.Rollback()
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
savedConfigVariables[i] = applicationConfigurationVariable
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
if err := common.LoadDbConfigFromDb(); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, savedConfigVariables)
|
||||
|
||||
}
|
||||
|
||||
func getLogoHandler(c *gin.Context) {
|
||||
imagType := common.DbConfig.LogoImageType.Value
|
||||
getImage(c, "logo", imagType)
|
||||
}
|
||||
|
||||
func getFaviconHandler(c *gin.Context) {
|
||||
getImage(c, "favicon", "ico")
|
||||
}
|
||||
|
||||
func getBackgroundImageHandler(c *gin.Context) {
|
||||
imageType := common.DbConfig.BackgroundImageType.Value
|
||||
getImage(c, "background", imageType)
|
||||
}
|
||||
|
||||
func updateLogoHandler(c *gin.Context) {
|
||||
imageType := common.DbConfig.LogoImageType.Value
|
||||
updateImage(c, "logo", imageType)
|
||||
}
|
||||
|
||||
func updateFaviconHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if fileType != "ico" {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "File must be of type .ico")
|
||||
return
|
||||
}
|
||||
updateImage(c, "favicon", "ico")
|
||||
}
|
||||
|
||||
func updateBackgroundImageHandler(c *gin.Context) {
|
||||
imagType := common.DbConfig.BackgroundImageType.Value
|
||||
updateImage(c, "background", imagType)
|
||||
}
|
||||
|
||||
func getImage(c *gin.Context, name string, imageType string) {
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, name, imageType)
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
}
|
||||
|
||||
func updateImage(c *gin.Context, imageName string, oldImageType string) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "File type not supported")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the old image if it has a different file type
|
||||
if fileType != oldImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
|
||||
err = c.SaveUploadedFile(file, imagePath)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Update the file type in the database
|
||||
key := fmt.Sprintf("%sImageType", imageName)
|
||||
err = common.DB.Model(&model.ApplicationConfigurationVariable{}).Where("key = ?", key).Update("value", fileType).Error
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.LoadDbConfigFromDb(); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func listApplicationConfiguration(c *gin.Context, showAll bool) {
|
||||
var configuration []model.ApplicationConfigurationVariable
|
||||
var err error
|
||||
|
||||
if showAll {
|
||||
err = common.DB.Find(&configuration).Error
|
||||
} else {
|
||||
err = common.DB.Find(&configuration, "is_public = true").Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, configuration)
|
||||
}
|
||||
@@ -1,415 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/common/middleware"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RegisterOIDCRoutes(group *gin.RouterGroup) {
|
||||
group.POST("/oidc/authorize", middleware.JWTAuth(false), authorizeHandler)
|
||||
group.POST("/oidc/authorize/new-client", middleware.JWTAuth(false), authorizeNewClientHandler)
|
||||
group.POST("/oidc/token", createIDTokenHandler)
|
||||
|
||||
group.GET("/oidc/clients", middleware.JWTAuth(true), listClientsHandler)
|
||||
group.POST("/oidc/clients", middleware.JWTAuth(true), createClientHandler)
|
||||
group.GET("/oidc/clients/:id", getClientHandler)
|
||||
group.PUT("/oidc/clients/:id", middleware.JWTAuth(true), updateClientHandler)
|
||||
group.DELETE("/oidc/clients/:id", middleware.JWTAuth(true), deleteClientHandler)
|
||||
|
||||
group.POST("/oidc/clients/:id/secret", middleware.JWTAuth(true), createClientSecretHandler)
|
||||
|
||||
group.GET("/oidc/clients/:id/logo", getClientLogoHandler)
|
||||
group.DELETE("/oidc/clients/:id/logo", deleteClientLogoHandler)
|
||||
group.POST("/oidc/clients/:id/logo", middleware.JWTAuth(true), middleware.LimitFileSize(2<<20), updateClientLogoHandler)
|
||||
}
|
||||
|
||||
type AuthorizeRequest struct {
|
||||
ClientID string `json:"clientID" binding:"required"`
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
func authorizeHandler(c *gin.Context) {
|
||||
var parsedBody AuthorizeRequest
|
||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
common.DB.First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", parsedBody.ClientID, c.GetString("userID"))
|
||||
|
||||
// If the record isn't found or the scope is different return an error
|
||||
// The client will have to call the authorizeNewClientHandler
|
||||
if userAuthorizedOIDCClient.Scope != parsedBody.Scope {
|
||||
utils.HandlerError(c, http.StatusForbidden, "missing authorization")
|
||||
return
|
||||
}
|
||||
|
||||
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
|
||||
}
|
||||
|
||||
// authorizeNewClientHandler authorizes a new client for the user
|
||||
// a new client is a new client when the user has not authorized the client before
|
||||
func authorizeNewClientHandler(c *gin.Context) {
|
||||
var parsedBody model.AuthorizeNewClientDto
|
||||
if err := c.ShouldBindJSON(&parsedBody); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: c.GetString("userID"),
|
||||
ClientID: parsedBody.ClientID,
|
||||
Scope: parsedBody.Scope,
|
||||
}
|
||||
err := common.DB.Create(&userAuthorizedClient).Error
|
||||
|
||||
if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
err = common.DB.Model(&userAuthorizedClient).Update("scope", parsedBody.Scope).Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
authorizationCode, err := createAuthorizationCode(parsedBody.ClientID, c.GetString("userID"), parsedBody.Scope, parsedBody.Nonce)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": authorizationCode})
|
||||
|
||||
}
|
||||
|
||||
func createIDTokenHandler(c *gin.Context) {
|
||||
var body model.OidcIdTokenDto
|
||||
|
||||
if err := c.ShouldBind(&body); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Currently only authorization_code grant type is supported
|
||||
if body.GrantType != "authorization_code" {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "grant type not supported")
|
||||
return
|
||||
}
|
||||
|
||||
clientID := body.ClientID
|
||||
clientSecret := body.ClientSecret
|
||||
|
||||
// Client id and secret can also be passed over the Authorization header
|
||||
if clientID == "" || clientSecret == "" {
|
||||
var ok bool
|
||||
clientID, clientSecret, ok = c.Request.BasicAuth()
|
||||
if !ok {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "Client id and secret not provided")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get the client
|
||||
var client model.OidcClient
|
||||
err := common.DB.First(&client, "id = ?", clientID, clientSecret).Error
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "OIDC OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if client secret is correct
|
||||
err = bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid client secret")
|
||||
return
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = common.DB.Preload("User").First(&authorizationCodeMetaData, "code = ?", body.Code).Error
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client id matches the client id in the authorization code and if the code has expired
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid authorization code")
|
||||
return
|
||||
}
|
||||
|
||||
idToken, e := common.GenerateIDToken(authorizationCodeMetaData.User, clientID, authorizationCodeMetaData.Scope, authorizationCodeMetaData.Nonce)
|
||||
if e != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the authorization code after it has been used
|
||||
common.DB.Delete(&authorizationCodeMetaData)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"id_token": idToken})
|
||||
}
|
||||
|
||||
func getClientHandler(c *gin.Context) {
|
||||
clientId := c.Param("id")
|
||||
|
||||
var client model.OidcClient
|
||||
err := common.DB.First(&client, "id = ?", clientId).Error
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, client)
|
||||
}
|
||||
|
||||
func listClientsHandler(c *gin.Context) {
|
||||
var clients []model.OidcClient
|
||||
searchTerm := c.Query("search")
|
||||
|
||||
query := common.DB.Model(&model.OidcClient{})
|
||||
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("name LIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.Paginate(c, query, &clients)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": clients,
|
||||
"pagination": pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func createClientHandler(c *gin.Context) {
|
||||
var input model.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURL: input.CallbackURL,
|
||||
CreatedByID: c.GetString("userID"),
|
||||
}
|
||||
|
||||
if err := common.DB.Create(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, client)
|
||||
}
|
||||
|
||||
func deleteClientHandler(c *gin.Context) {
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.DB.Delete(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func updateClientHandler(c *gin.Context) {
|
||||
var input model.OidcClientCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURL = input.CallbackURL
|
||||
|
||||
if err := common.DB.Save(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNoContent, client)
|
||||
}
|
||||
|
||||
// createClientSecretHandler creates a new secret for the client and revokes the old one
|
||||
func createClientSecretHandler(c *gin.Context) {
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client.Secret = string(hashedSecret)
|
||||
if err := common.DB.Save(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"secret": clientSecret})
|
||||
}
|
||||
|
||||
func getClientLogoHandler(c *gin.Context) {
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "image not found")
|
||||
return
|
||||
}
|
||||
|
||||
imageType := *client.ImageType
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
c.Header("Content-Type", mimeType)
|
||||
c.File(imagePath)
|
||||
}
|
||||
|
||||
func updateClientLogoHandler(c *gin.Context) {
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "file type not supported")
|
||||
return
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, c.Param("id"), fileType)
|
||||
err = c.SaveUploadedFile(file, imagePath)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the old image if it has a different file type
|
||||
if client.ImageType != nil && fileType != *client.ImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
client.ImageType = &fileType
|
||||
if err := common.DB.Save(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func deleteClientLogoHandler(c *gin.Context) {
|
||||
var client model.OidcClient
|
||||
if err := common.DB.First(&client, "id = ?", c.Param("id")).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "OIDC client not found")
|
||||
return
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "image not found")
|
||||
return
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
client.ImageType = nil
|
||||
if err := common.DB.Save(&client).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
oidcAuthorizationCode := model.OidcAuthorizationCode{
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
Code: randomString,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
if err := common.DB.Create(&oidcAuthorizationCode).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return randomString, nil
|
||||
}
|
||||
@@ -1,276 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/common/middleware"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RegisterUserRoutes(group *gin.RouterGroup) {
|
||||
group.GET("/users", middleware.JWTAuth(true), listUsersHandler)
|
||||
group.GET("/users/me", middleware.JWTAuth(false), getCurrentUserHandler)
|
||||
group.GET("/users/:id", middleware.JWTAuth(true), getUserHandler)
|
||||
group.POST("/users", middleware.JWTAuth(true), createUserHandler)
|
||||
group.PUT("/users/:id", middleware.JWTAuth(true), updateUserHandler)
|
||||
group.PUT("/users/me", middleware.JWTAuth(false), updateCurrentUserHandler)
|
||||
group.DELETE("/users/:id", middleware.JWTAuth(true), deleteUserHandler)
|
||||
|
||||
group.POST("/users/:id/one-time-access-token", middleware.JWTAuth(true), createOneTimeAccessTokenHandler)
|
||||
group.POST("/one-time-access-token/:token", middleware.RateLimiter(rate.Every(10*time.Second), 5), exchangeOneTimeAccessTokenHandler)
|
||||
group.POST("/one-time-access-token/setup", getSetupAccessTokenHandler)
|
||||
}
|
||||
|
||||
func listUsersHandler(c *gin.Context) {
|
||||
var users []model.User
|
||||
searchTerm := c.Query("search")
|
||||
|
||||
query := common.DB.Model(&model.User{})
|
||||
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.Paginate(c, query, &users)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": users,
|
||||
"pagination": pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func getUserHandler(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func getCurrentUserHandler(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := common.DB.Where("id = ?", c.GetString("userID")).First(&user).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, user)
|
||||
|
||||
}
|
||||
|
||||
func deleteUserHandler(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := common.DB.Where("id = ?", c.Param("id")).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.DB.Delete(&user).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func createUserHandler(c *gin.Context) {
|
||||
var user model.User
|
||||
if err := c.ShouldBindJSON(&user); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.DB.Create(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
if err := checkDuplicatedFields(user); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, user)
|
||||
}
|
||||
|
||||
func updateUserHandler(c *gin.Context) {
|
||||
updateUser(c, c.Param("id"), false)
|
||||
}
|
||||
|
||||
func updateCurrentUserHandler(c *gin.Context) {
|
||||
updateUser(c, c.GetString("userID"), true)
|
||||
}
|
||||
|
||||
func createOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
var input model.OneTimeAccessTokenCreateDto
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
oneTimeAccessToken := model.OneTimeAccessToken{
|
||||
UserID: input.UserID,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Token: randomString,
|
||||
}
|
||||
|
||||
if err := common.DB.Create(&oneTimeAccessToken).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{"token": oneTimeAccessToken.Token})
|
||||
}
|
||||
|
||||
func exchangeOneTimeAccessTokenHandler(c *gin.Context) {
|
||||
var oneTimeAccessToken model.OneTimeAccessToken
|
||||
if err := common.DB.Where("token = ? AND expires_at > ?", c.Param("token"), utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusForbidden, "Token is invalid or expired")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := common.GenerateAccessToken(oneTimeAccessToken.User)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.DB.Delete(&oneTimeAccessToken).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
|
||||
c.JSON(http.StatusOK, oneTimeAccessToken.User)
|
||||
}
|
||||
|
||||
// getSetupAccessTokenHandler creates the initial admin user and returns an access token for the user
|
||||
// This handler is only available if there are no users in the database
|
||||
func getSetupAccessTokenHandler(c *gin.Context) {
|
||||
var userCount int64
|
||||
if err := common.DB.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
log.Fatal("failed to count users", err)
|
||||
}
|
||||
|
||||
// If there are more than one user, we don't need to create the admin user
|
||||
if userCount > 1 {
|
||||
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
|
||||
return
|
||||
}
|
||||
|
||||
var user = model.User{
|
||||
FirstName: "Admin",
|
||||
LastName: "Admin",
|
||||
Username: "admin",
|
||||
Email: "admin@admin.com",
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
// Create the initial admin user if it doesn't exist
|
||||
if err := common.DB.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||
log.Fatal("failed to create admin user", err)
|
||||
}
|
||||
|
||||
// If the user already has credentials, the setup is already completed
|
||||
if len(user.Credentials) > 0 {
|
||||
utils.HandlerError(c, http.StatusForbidden, "Setup already completed")
|
||||
return
|
||||
}
|
||||
|
||||
token, err := common.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func updateUser(c *gin.Context, userID string, updateOwnUser bool) {
|
||||
var user model.User
|
||||
if err := common.DB.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
utils.HandlerError(c, http.StatusNotFound, "User not found")
|
||||
return
|
||||
}
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
var updatedUser model.User
|
||||
if err := c.ShouldBindJSON(&updatedUser); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
user.FirstName = updatedUser.FirstName
|
||||
user.LastName = updatedUser.LastName
|
||||
user.Email = updatedUser.Email
|
||||
user.Username = updatedUser.Username
|
||||
user.Username = updatedUser.Username
|
||||
if !updateOwnUser {
|
||||
user.IsAdmin = updatedUser.IsAdmin
|
||||
}
|
||||
|
||||
if err := common.DB.Save(user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
if err := checkDuplicatedFields(user); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func checkDuplicatedFields(user model.User) error {
|
||||
var existingUser model.User
|
||||
|
||||
if common.DB.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
||||
return errors.New("email is already taken")
|
||||
}
|
||||
|
||||
if common.DB.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
||||
return errors.New("username is already taken")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,257 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/common/middleware"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RegisterRoutes(group *gin.RouterGroup) {
|
||||
group.GET("/webauthn/register/start", middleware.JWTAuth(false), beginRegistrationHandler)
|
||||
group.POST("/webauthn/register/finish", middleware.JWTAuth(false), verifyRegistrationHandler)
|
||||
|
||||
group.GET("/webauthn/login/start", beginLoginHandler)
|
||||
group.POST("/webauthn/login/finish", middleware.RateLimiter(rate.Every(10*time.Second), 5), verifyLoginHandler)
|
||||
|
||||
group.POST("/webauthn/logout", middleware.JWTAuth(false), logoutHandler)
|
||||
|
||||
group.GET("/webauthn/credentials", middleware.JWTAuth(false), listCredentialsHandler)
|
||||
group.PATCH("/webauthn/credentials/:id", middleware.JWTAuth(false), updateCredentialHandler)
|
||||
group.DELETE("/webauthn/credentials/:id", middleware.JWTAuth(false), deleteCredentialHandler)
|
||||
}
|
||||
|
||||
func beginRegistrationHandler(c *gin.Context) {
|
||||
var user model.User
|
||||
err := common.DB.Preload("Credentials").Find(&user, "id = ?", c.GetString("userID")).Error
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
options, session, err := common.WebAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save the webauthn session so we can retrieve it in the verifyRegistrationHandler
|
||||
sessionToStore := &model.WebauthnSession{
|
||||
ExpiresAt: session.Expires,
|
||||
Challenge: session.Challenge,
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err = common.DB.Create(&sessionToStore).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, options.Response)
|
||||
}
|
||||
|
||||
func verifyRegistrationHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve the session that was previously created by the beginRegistrationHandler
|
||||
var storedSession model.WebauthnSession
|
||||
err = common.DB.First(&storedSession, "id = ?", sessionID).Error
|
||||
|
||||
session := webauthn.SessionData{
|
||||
Challenge: storedSession.Challenge,
|
||||
Expires: storedSession.ExpiresAt,
|
||||
UserID: []byte(c.GetString("userID")),
|
||||
}
|
||||
|
||||
var user model.User
|
||||
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
credential, err := common.WebAuthn.FinishRegistration(&user, session, c.Request)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
credentialToStore := model.WebauthnCredential{
|
||||
Name: "New Passkey",
|
||||
CredentialID: string(credential.ID),
|
||||
AttestationType: credential.AttestationType,
|
||||
PublicKey: credential.PublicKey,
|
||||
Transport: credential.Transport,
|
||||
UserID: user.ID,
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
}
|
||||
if err := common.DB.Create(&credentialToStore).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, credentialToStore)
|
||||
}
|
||||
|
||||
func beginLoginHandler(c *gin.Context) {
|
||||
options, session, err := common.WebAuthn.BeginDiscoverableLogin()
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Save the webauthn session so we can retrieve it in the verifyLoginHandler
|
||||
sessionToStore := &model.WebauthnSession{
|
||||
ExpiresAt: session.Expires,
|
||||
Challenge: session.Challenge,
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err = common.DB.Create(&sessionToStore).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("session_id", sessionToStore.ID, int(common.WebAuthn.Config.Timeouts.Registration.Timeout.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, options.Response)
|
||||
}
|
||||
|
||||
func verifyLoginHandler(c *gin.Context) {
|
||||
sessionID, err := c.Cookie("session_id")
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "Session ID missing")
|
||||
return
|
||||
}
|
||||
|
||||
credentialAssertionData, err := protocol.ParseCredentialRequestResponseBody(c.Request.Body)
|
||||
if err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "Invalid body")
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve the session that was previously created by the beginLoginHandler
|
||||
var storedSession model.WebauthnSession
|
||||
if err := common.DB.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
session := webauthn.SessionData{
|
||||
Challenge: storedSession.Challenge,
|
||||
Expires: storedSession.ExpiresAt,
|
||||
}
|
||||
|
||||
var user *model.User
|
||||
_, err = common.WebAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
if err := common.DB.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}, session, credentialAssertionData)
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), gorm.ErrRecordNotFound.Error()) {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "no user with this passkey exists")
|
||||
} else {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = common.DB.Find(&user, "id = ?", c.GetString("userID")).Error
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := common.GenerateAccessToken(*user)
|
||||
if err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.SetCookie("access_token", token, int(time.Hour.Seconds()), "/", "", false, true)
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func listCredentialsHandler(c *gin.Context) {
|
||||
var credentials []model.WebauthnCredential
|
||||
if err := common.DB.Find(&credentials, "user_id = ?", c.GetString("userID")).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, credentials)
|
||||
}
|
||||
|
||||
func deleteCredentialHandler(c *gin.Context) {
|
||||
var passkeyCount int64
|
||||
if err := common.DB.Model(&model.WebauthnCredential{}).Where("user_id = ?", c.GetString("userID")).Count(&passkeyCount).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if passkeyCount == 1 {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "You must have at least one passkey")
|
||||
return
|
||||
}
|
||||
|
||||
var credential model.WebauthnCredential
|
||||
if err := common.DB.First(&credential, "id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := common.DB.Delete(&credential).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func updateCredentialHandler(c *gin.Context) {
|
||||
var credential model.WebauthnCredential
|
||||
if err := common.DB.Where("id = ? AND user_id = ?", c.Param("id"), c.GetString("userID")).First(&credential).Error; err != nil {
|
||||
utils.HandlerError(c, http.StatusNotFound, "Credential not found")
|
||||
return
|
||||
}
|
||||
|
||||
var input struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
utils.HandlerError(c, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
credential.Name = input.Name
|
||||
|
||||
if err := common.DB.Save(&credential).Error; err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func logoutHandler(c *gin.Context) {
|
||||
c.SetCookie("access_token", "", 0, "/", "", false, true)
|
||||
c.Status(http.StatusNoContent)
|
||||
}
|
||||
@@ -3,28 +3,46 @@ package job
|
||||
import (
|
||||
"github.com/go-co-op/gocron/v2"
|
||||
"github.com/google/uuid"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func RegisterJobs() {
|
||||
func RegisterJobs(db *gorm.DB) {
|
||||
scheduler, err := gocron.NewScheduler()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create a new scheduler: %s", err)
|
||||
}
|
||||
|
||||
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", clearWebauthnSessions)
|
||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", clearOneTimeAccessTokens)
|
||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", clearOidcAuthorizationCodes)
|
||||
jobs := &Jobs{db: db}
|
||||
|
||||
registerJob(scheduler, "ClearWebauthnSessions", "0 3 * * *", jobs.clearWebauthnSessions)
|
||||
registerJob(scheduler, "ClearOneTimeAccessTokens", "0 3 * * *", jobs.clearOneTimeAccessTokens)
|
||||
registerJob(scheduler, "ClearOidcAuthorizationCodes", "0 3 * * *", jobs.clearOidcAuthorizationCodes)
|
||||
|
||||
scheduler.Start()
|
||||
}
|
||||
|
||||
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
|
||||
type Jobs struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (j *Jobs) clearWebauthnSessions() error {
|
||||
return j.db.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
}
|
||||
|
||||
func (j *Jobs) clearOneTimeAccessTokens() error {
|
||||
return j.db.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
}
|
||||
|
||||
func (j *Jobs) clearOidcAuthorizationCodes() error {
|
||||
return j.db.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
|
||||
}
|
||||
|
||||
func registerJob(scheduler gocron.Scheduler, name string, interval string, job func() error) {
|
||||
_, err := scheduler.NewJob(
|
||||
gocron.CronJob(interval, false),
|
||||
gocron.NewTask(job),
|
||||
@@ -42,16 +60,3 @@ func registerJob(scheduler gocron.Scheduler, name string, interval string, job f
|
||||
log.Fatalf("Failed to register job %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
func clearWebauthnSessions() error {
|
||||
return common.DB.Delete(&model.WebauthnSession{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
}
|
||||
|
||||
func clearOneTimeAccessTokens() error {
|
||||
return common.DB.Debug().Delete(&model.OneTimeAccessToken{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
}
|
||||
|
||||
func clearOidcAuthorizationCodes() error {
|
||||
return common.DB.Delete(&model.OidcAuthorizationCode{}, "expires_at < ?", utils.FormatDateForDb(time.Now())).Error
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"golang-rest-api-template/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Cors() gin.HandlerFunc {
|
||||
type CorsMiddleware struct{}
|
||||
|
||||
func NewCorsMiddleware() *CorsMiddleware {
|
||||
return &CorsMiddleware{}
|
||||
}
|
||||
|
||||
func (m *CorsMiddleware) Add() gin.HandlerFunc {
|
||||
return cors.New(cors.Config{
|
||||
AllowOrigins: []string{common.EnvConfig.AppURL},
|
||||
AllowMethods: []string{"*"},
|
||||
@@ -3,15 +3,21 @@ package middleware
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func LimitFileSize(maxSize int64) gin.HandlerFunc {
|
||||
type FileSizeLimitMiddleware struct{}
|
||||
|
||||
func NewFileSizeLimitMiddleware() *FileSizeLimitMiddleware {
|
||||
return &FileSizeLimitMiddleware{}
|
||||
}
|
||||
|
||||
func (m *FileSizeLimitMiddleware) Add(maxSize int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxSize)
|
||||
if err := c.Request.ParseMultipartForm(maxSize); err != nil {
|
||||
utils.HandlerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
||||
utils.CustomControllerError(c, http.StatusRequestEntityTooLarge, fmt.Sprintf("The file can't be larger than %s bytes", formatFileSize(maxSize)))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
59
backend/internal/middleware/jwt_auth.go
Normal file
59
backend/internal/middleware/jwt_auth.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stonith404/pocket-id/backend/internal/service"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JwtAuthMiddleware struct {
|
||||
jwtService *service.JwtService
|
||||
ignoreUnauthenticated bool
|
||||
}
|
||||
|
||||
func NewJwtAuthMiddleware(jwtService *service.JwtService, ignoreUnauthenticated bool) *JwtAuthMiddleware {
|
||||
return &JwtAuthMiddleware{jwtService: jwtService, ignoreUnauthenticated: ignoreUnauthenticated}
|
||||
}
|
||||
|
||||
func (m *JwtAuthMiddleware) Add(adminOnly bool) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Extract the token from the cookie or the Authorization header
|
||||
token, err := c.Cookie("access_token")
|
||||
if err != nil {
|
||||
authorizationHeaderSplitted := strings.Split(c.GetHeader("Authorization"), " ")
|
||||
if len(authorizationHeaderSplitted) == 2 {
|
||||
token = authorizationHeaderSplitted[1]
|
||||
} else if m.ignoreUnauthenticated {
|
||||
c.Next()
|
||||
return
|
||||
} else {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
claims, err := m.jwtService.VerifyAccessToken(token)
|
||||
if err != nil && m.ignoreUnauthenticated {
|
||||
c.Next()
|
||||
return
|
||||
} else if err != nil {
|
||||
utils.CustomControllerError(c, http.StatusUnauthorized, "You're not signed in")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is an admin
|
||||
if adminOnly && !claims.IsAdmin {
|
||||
utils.CustomControllerError(c, http.StatusForbidden, "You don't have permission to access this resource")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("userID", claims.Subject)
|
||||
c.Set("userIsAdmin", claims.IsAdmin)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,8 +11,13 @@ import (
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter is a Gin middleware for rate limiting based on client IP
|
||||
func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
type RateLimitMiddleware struct{}
|
||||
|
||||
func NewRateLimitMiddleware() *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) Add(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
// Start the cleanup routine
|
||||
go cleanupClients()
|
||||
|
||||
@@ -28,7 +33,7 @@ func RateLimiter(limit rate.Limit, burst int) gin.HandlerFunc {
|
||||
|
||||
limiter := getLimiter(ip, limit, burst)
|
||||
if !limiter.Allow() {
|
||||
utils.HandlerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
||||
utils.CustomControllerError(c, http.StatusTooManyRequests, "Too many requests. Please wait a while before trying again.")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
16
backend/internal/model/app_config.go
Normal file
16
backend/internal/model/app_config.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
type AppConfigVariable struct {
|
||||
Key string `gorm:"primaryKey;not null"`
|
||||
Type string
|
||||
IsPublic bool
|
||||
IsInternal bool
|
||||
Value string
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
AppName AppConfigVariable
|
||||
BackgroundImageType AppConfigVariable
|
||||
LogoImageType AppConfigVariable
|
||||
SessionDuration AppConfigVariable
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package model
|
||||
|
||||
type ApplicationConfigurationVariable struct {
|
||||
Key string `gorm:"primaryKey;not null" json:"key"`
|
||||
Type string `json:"type"`
|
||||
IsPublic bool `json:"-"`
|
||||
IsInternal bool `json:"-"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type ApplicationConfiguration struct {
|
||||
AppName ApplicationConfigurationVariable
|
||||
BackgroundImageType ApplicationConfigurationVariable
|
||||
LogoImageType ApplicationConfigurationVariable
|
||||
SessionDuration ApplicationConfigurationVariable
|
||||
}
|
||||
|
||||
type ApplicationConfigurationUpdateDto struct {
|
||||
AppName string `json:"appName" binding:"required"`
|
||||
}
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
|
||||
// Base contains common columns for all tables.
|
||||
type Base struct {
|
||||
ID string `gorm:"primaryKey;not null" json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ID string `gorm:"primaryKey;not null"`
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func (b *Base) BeforeCreate(db *gorm.DB) (err error) {
|
||||
func (b *Base) BeforeCreate(_ *gorm.DB) (err error) {
|
||||
if b.ID == "" {
|
||||
b.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
@@ -1,37 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserAuthorizedOidcClient struct {
|
||||
Scope string
|
||||
UserID string `json:"userId" gorm:"primary_key;"`
|
||||
UserID string `gorm:"primary_key;"`
|
||||
User User
|
||||
|
||||
ClientID string `json:"clientId" gorm:"primary_key;"`
|
||||
ClientID string `gorm:"primary_key;"`
|
||||
Client OidcClient
|
||||
}
|
||||
|
||||
type OidcClient struct {
|
||||
Base
|
||||
|
||||
Name string `json:"name"`
|
||||
Secret string `json:"-"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
ImageType *string `json:"-"`
|
||||
HasLogo bool `gorm:"-" json:"hasLogo"`
|
||||
|
||||
CreatedByID string
|
||||
CreatedBy User
|
||||
}
|
||||
|
||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
// Compute HasLogo field
|
||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type OidcAuthorizationCode struct {
|
||||
Base
|
||||
|
||||
@@ -46,20 +31,35 @@ type OidcAuthorizationCode struct {
|
||||
ClientID string
|
||||
}
|
||||
|
||||
type OidcClientCreateDto struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
CallbackURL string `json:"callbackURL" binding:"required"`
|
||||
type OidcClient struct {
|
||||
Base
|
||||
|
||||
Name string
|
||||
Secret string
|
||||
CallbackURLs CallbackURLs
|
||||
ImageType *string
|
||||
HasLogo bool `gorm:"-"`
|
||||
|
||||
CreatedByID string
|
||||
CreatedBy User
|
||||
}
|
||||
|
||||
type AuthorizeNewClientDto struct {
|
||||
ClientID string `json:"clientID" binding:"required"`
|
||||
Scope string `json:"scope" binding:"required"`
|
||||
Nonce string `json:"nonce"`
|
||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
||||
// Compute HasLogo field
|
||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
||||
return nil
|
||||
}
|
||||
|
||||
type OidcIdTokenDto struct {
|
||||
GrantType string `form:"grant_type" binding:"required"`
|
||||
Code string `form:"code" binding:"required"`
|
||||
ClientID string `form:"client_id"`
|
||||
ClientSecret string `form:"client_secret"`
|
||||
type CallbackURLs []string
|
||||
|
||||
func (cu *CallbackURLs) Scan(value interface{}) error {
|
||||
if v, ok := value.([]byte); ok {
|
||||
return json.Unmarshal(v, cu)
|
||||
} else {
|
||||
return errors.New("type assertion to []byte failed")
|
||||
}
|
||||
}
|
||||
|
||||
func (cu CallbackURLs) Value() (driver.Value, error) {
|
||||
return json.Marshal(cu)
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ import (
|
||||
type User struct {
|
||||
Base
|
||||
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email" `
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
IsAdmin bool `json:"isAdmin"`
|
||||
Username string
|
||||
Email string
|
||||
FirstName string
|
||||
LastName string
|
||||
IsAdmin bool
|
||||
|
||||
Credentials []WebauthnCredential `json:"-"`
|
||||
Credentials []WebauthnCredential
|
||||
}
|
||||
|
||||
func (u User) WebAuthnID() []byte { return []byte(u.ID) }
|
||||
@@ -59,19 +59,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
|
||||
|
||||
type OneTimeAccessToken struct {
|
||||
Base
|
||||
Token string `json:"token"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
|
||||
UserID string `json:"userId"`
|
||||
UserID string
|
||||
User User
|
||||
}
|
||||
|
||||
type OneTimeAccessTokenCreateDto struct {
|
||||
UserID string `json:"userId" binding:"required"`
|
||||
ExpiresAt time.Time `json:"expiresAt" binding:"required"`
|
||||
}
|
||||
|
||||
type LoginUserDto struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -19,11 +19,11 @@ type WebauthnSession struct {
|
||||
type WebauthnCredential struct {
|
||||
Base
|
||||
|
||||
Name string `json:"name"`
|
||||
CredentialID string `json:"credentialID"`
|
||||
PublicKey []byte `json:"-"`
|
||||
AttestationType string `json:"attestationType"`
|
||||
Transport AuthenticatorTransportList `json:"-"`
|
||||
Name string
|
||||
CredentialID string
|
||||
PublicKey []byte
|
||||
AttestationType string
|
||||
Transport AuthenticatorTransportList
|
||||
|
||||
BackupEligible bool `json:"backupEligible"`
|
||||
BackupState bool `json:"backupState"`
|
||||
@@ -31,6 +31,18 @@ type WebauthnCredential struct {
|
||||
UserID string
|
||||
}
|
||||
|
||||
type PublicKeyCredentialCreationOptions struct {
|
||||
Response protocol.PublicKeyCredentialCreationOptions
|
||||
SessionID string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type PublicKeyCredentialRequestOptions struct {
|
||||
Response protocol.PublicKeyCredentialRequestOptions
|
||||
SessionID string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type AuthenticatorTransportList []protocol.AuthenticatorTransport
|
||||
|
||||
// Scan and Value methods for GORM to handle the custom type
|
||||
|
||||
214
backend/internal/service/app_config_service.go
Normal file
214
backend/internal/service/app_config_service.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type AppConfigService struct {
|
||||
DbConfig *model.AppConfig
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAppConfigService(db *gorm.DB) *AppConfigService {
|
||||
service := &AppConfigService{
|
||||
DbConfig: &defaultDbConfig,
|
||||
db: db,
|
||||
}
|
||||
if err := service.InitDbConfig(); err != nil {
|
||||
log.Fatalf("Failed to initialize app config service: %v", err)
|
||||
}
|
||||
return service
|
||||
}
|
||||
|
||||
var defaultDbConfig = model.AppConfig{
|
||||
AppName: model.AppConfigVariable{
|
||||
Key: "appName",
|
||||
Type: "string",
|
||||
IsPublic: true,
|
||||
Value: "Pocket ID",
|
||||
},
|
||||
SessionDuration: model.AppConfigVariable{
|
||||
Key: "sessionDuration",
|
||||
Type: "number",
|
||||
Value: "60",
|
||||
},
|
||||
BackgroundImageType: model.AppConfigVariable{
|
||||
Key: "backgroundImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "jpg",
|
||||
},
|
||||
LogoImageType: model.AppConfigVariable{
|
||||
Key: "logoImageType",
|
||||
Type: "string",
|
||||
IsInternal: true,
|
||||
Value: "svg",
|
||||
},
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateApplicationConfiguration(input dto.AppConfigUpdateDto) ([]model.AppConfigVariable, error) {
|
||||
var savedConfigVariables []model.AppConfigVariable
|
||||
|
||||
tx := s.db.Begin()
|
||||
rt := reflect.ValueOf(input).Type()
|
||||
rv := reflect.ValueOf(input)
|
||||
|
||||
for i := 0; i < rt.NumField(); i++ {
|
||||
field := rt.Field(i)
|
||||
key := field.Tag.Get("json")
|
||||
value := rv.FieldByName(field.Name).String()
|
||||
|
||||
var applicationConfigurationVariable model.AppConfigVariable
|
||||
if err := tx.First(&applicationConfigurationVariable, "key = ? AND is_internal = false", key).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
applicationConfigurationVariable.Value = value
|
||||
if err := tx.Save(&applicationConfigurationVariable).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
savedConfigVariables = append(savedConfigVariables, applicationConfigurationVariable)
|
||||
}
|
||||
|
||||
tx.Commit()
|
||||
|
||||
if err := s.loadDbConfigFromDb(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return savedConfigVariables, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateImageType(imageName string, fileType string) error {
|
||||
key := fmt.Sprintf("%sImageType", imageName)
|
||||
err := s.db.Model(&model.AppConfigVariable{}).Where("key = ?", key).Update("value", fileType).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.loadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *AppConfigService) ListApplicationConfiguration(showAll bool) ([]model.AppConfigVariable, error) {
|
||||
var configuration []model.AppConfigVariable
|
||||
var err error
|
||||
|
||||
if showAll {
|
||||
err = s.db.Find(&configuration).Error
|
||||
} else {
|
||||
err = s.db.Find(&configuration, "is_public = true").Error
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return configuration, nil
|
||||
}
|
||||
|
||||
func (s *AppConfigService) UpdateImage(uploadedFile *multipart.FileHeader, imageName string, oldImageType string) error {
|
||||
fileType := utils.GetFileExtension(uploadedFile.Filename)
|
||||
mimeType := utils.GetImageMimeType(fileType)
|
||||
if mimeType == "" {
|
||||
return common.ErrFileTypeNotSupported
|
||||
}
|
||||
|
||||
// Delete the old image if it has a different file type
|
||||
if fileType != oldImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, oldImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/application-images/%s.%s", common.EnvConfig.UploadPath, imageName, fileType)
|
||||
if err := utils.SaveFile(uploadedFile, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the file type in the database
|
||||
if err := s.UpdateImageType(imageName, fileType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitDbConfig creates the default configuration values in the database if they do not exist,
|
||||
// updates existing configurations if they differ from the default, and deletes any configurations
|
||||
// that are not in the default configuration.
|
||||
func (s *AppConfigService) InitDbConfig() error {
|
||||
// Reflect to get the underlying value of DbConfig and its default configuration
|
||||
defaultConfigReflectValue := reflect.ValueOf(defaultDbConfig)
|
||||
defaultKeys := make(map[string]struct{})
|
||||
|
||||
// Iterate over the fields of DbConfig
|
||||
for i := 0; i < defaultConfigReflectValue.NumField(); i++ {
|
||||
defaultConfigVar := defaultConfigReflectValue.Field(i).Interface().(model.AppConfigVariable)
|
||||
|
||||
defaultKeys[defaultConfigVar.Key] = struct{}{}
|
||||
|
||||
var storedConfigVar model.AppConfigVariable
|
||||
if err := s.db.First(&storedConfigVar, "key = ?", defaultConfigVar.Key).Error; err != nil {
|
||||
// If the configuration does not exist, create it
|
||||
if err := s.db.Create(&defaultConfigVar).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Update existing configuration if it differs from the default
|
||||
if storedConfigVar.Type != defaultConfigVar.Type || storedConfigVar.IsPublic != defaultConfigVar.IsPublic || storedConfigVar.IsInternal != defaultConfigVar.IsInternal {
|
||||
storedConfigVar.Type = defaultConfigVar.Type
|
||||
storedConfigVar.IsPublic = defaultConfigVar.IsPublic
|
||||
storedConfigVar.IsInternal = defaultConfigVar.IsInternal
|
||||
if err := s.db.Save(&storedConfigVar).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete any configurations not in the default keys
|
||||
var allConfigVars []model.AppConfigVariable
|
||||
if err := s.db.Find(&allConfigVars).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, config := range allConfigVars {
|
||||
if _, exists := defaultKeys[config.Key]; !exists {
|
||||
if err := s.db.Delete(&config).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.loadDbConfigFromDb()
|
||||
}
|
||||
|
||||
func (s *AppConfigService) loadDbConfigFromDb() error {
|
||||
dbConfigReflectValue := reflect.ValueOf(s.DbConfig).Elem()
|
||||
|
||||
for i := 0; i < dbConfigReflectValue.NumField(); i++ {
|
||||
dbConfigField := dbConfigReflectValue.Field(i)
|
||||
currentConfigVar := dbConfigField.Interface().(model.AppConfigVariable)
|
||||
var storedConfigVar model.AppConfigVariable
|
||||
if err := s.db.First(&storedConfigVar, "key = ?", currentConfigVar.Key).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbConfigField.Set(reflect.ValueOf(storedConfigVar))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
261
backend/internal/service/jwt_service.go
Normal file
261
backend/internal/service/jwt_service.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"log"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
privateKeyPath = "data/keys/jwt_private_key.pem"
|
||||
publicKeyPath = "data/keys/jwt_public_key.pem"
|
||||
)
|
||||
|
||||
type JwtService struct {
|
||||
publicKey *rsa.PublicKey
|
||||
privateKey *rsa.PrivateKey
|
||||
appConfigService *AppConfigService
|
||||
}
|
||||
|
||||
func NewJwtService(appConfigService *AppConfigService) *JwtService {
|
||||
service := &JwtService{
|
||||
appConfigService: appConfigService,
|
||||
}
|
||||
|
||||
// Ensure keys are generated or loaded
|
||||
if err := service.loadOrGenerateKeys(); err != nil {
|
||||
log.Fatalf("Failed to initialize jwt service: %v", err)
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
type AccessTokenJWTClaims struct {
|
||||
jwt.RegisteredClaims
|
||||
IsAdmin bool `json:"isAdmin,omitempty"`
|
||||
}
|
||||
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Use string `json:"use"`
|
||||
Alg string `json:"alg"`
|
||||
N string `json:"n"`
|
||||
E string `json:"e"`
|
||||
}
|
||||
|
||||
// loadOrGenerateKeys loads RSA keys from the given paths or generates them if they do not exist.
|
||||
func (s *JwtService) loadOrGenerateKeys() error {
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
if err := s.generateKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt private key: " + err.Error())
|
||||
}
|
||||
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt private key: " + err.Error())
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
if err != nil {
|
||||
return errors.New("can't read jwt public key: " + err.Error())
|
||||
}
|
||||
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return errors.New("can't parse jwt public key: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateAccessToken(user model.User) (string, error) {
|
||||
sessionDurationInMinutes, _ := strconv.Atoi(s.appConfigService.DbConfig.SessionDuration.Value)
|
||||
claim := AccessTokenJWTClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(sessionDurationInMinutes) * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{utils.GetHostFromURL(common.EnvConfig.AppURL)},
|
||||
},
|
||||
IsAdmin: user.IsAdmin,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyAccessToken(tokenString string) (*AccessTokenJWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &AccessTokenJWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*AccessTokenJWTClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
if !slices.Contains(claims.Audience, utils.GetHostFromURL(common.EnvConfig.AppURL)) {
|
||||
return nil, errors.New("audience doesn't match")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *JwtService) GenerateIDToken(userClaims map[string]interface{}, clientID string, nonce string) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"aud": clientID,
|
||||
"exp": jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
"iat": jwt.NewNumericDate(time.Now()),
|
||||
"iss": common.EnvConfig.AppURL,
|
||||
}
|
||||
|
||||
for k, v := range userClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
|
||||
if nonce != "" {
|
||||
claims["nonce"] = nonce
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
func (s *JwtService) GenerateOauthAccessToken(user model.User, clientID string) (string, error) {
|
||||
claim := jwt.RegisteredClaims{
|
||||
Subject: user.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(1 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Audience: jwt.ClaimStrings{clientID},
|
||||
Issuer: common.EnvConfig.AppURL,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claim)
|
||||
return token.SignedString(s.privateKey)
|
||||
}
|
||||
|
||||
func (s *JwtService) VerifyOauthAccessToken(tokenString string) (*jwt.RegisteredClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return s.publicKey, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("couldn't handle this token")
|
||||
}
|
||||
|
||||
claims, isValid := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !isValid {
|
||||
return nil, errors.New("can't parse claims")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// GetJWK returns the JSON Web Key (JWK) for the public key.
|
||||
func (s *JwtService) GetJWK() (JWK, error) {
|
||||
if s.publicKey == nil {
|
||||
return JWK{}, errors.New("public key is not initialized")
|
||||
}
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.publicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(s.publicKey.E)).Bytes()),
|
||||
}
|
||||
|
||||
return jwk, nil
|
||||
}
|
||||
|
||||
// generateKeys generates a new RSA key pair and saves them to the specified paths.
|
||||
func (s *JwtService) generateKeys() error {
|
||||
if err := os.MkdirAll(filepath.Dir(privateKeyPath), 0700); err != nil {
|
||||
return errors.New("failed to create directories for keys: " + err.Error())
|
||||
}
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return errors.New("failed to generate private key: " + err.Error())
|
||||
}
|
||||
s.privateKey = privateKey
|
||||
|
||||
if err := s.savePEMKey(privateKeyPath, x509.MarshalPKCS1PrivateKey(privateKey), "RSA PRIVATE KEY"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
s.publicKey = publicKey
|
||||
|
||||
if err := s.savePEMKey(publicKeyPath, x509.MarshalPKCS1PublicKey(publicKey), "RSA PUBLIC KEY"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// savePEMKey saves a PEM encoded key to a file.
|
||||
func (s *JwtService) savePEMKey(path string, keyBytes []byte, keyType string) error {
|
||||
keyFile, err := os.Create(path)
|
||||
if err != nil {
|
||||
return errors.New("failed to create key file: " + err.Error())
|
||||
}
|
||||
defer keyFile.Close()
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: keyType,
|
||||
Bytes: keyBytes,
|
||||
})
|
||||
|
||||
if _, err := keyFile.Write(keyPEM); err != nil {
|
||||
return errors.New("failed to write key file: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadKeys loads RSA keys from the given paths.
|
||||
func (s *JwtService) loadKeys() error {
|
||||
if _, err := os.Stat(privateKeyPath); os.IsNotExist(err) {
|
||||
if err := s.generateKeys(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
privateKeyBytes, err := os.ReadFile(privateKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't read jwt private key: %w", err)
|
||||
}
|
||||
s.privateKey, err = jwt.ParseRSAPrivateKeyFromPEM(privateKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't parse jwt private key: %w", err)
|
||||
}
|
||||
|
||||
publicKeyBytes, err := os.ReadFile(publicKeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't read jwt public key: %w", err)
|
||||
}
|
||||
s.publicKey, err = jwt.ParseRSAPublicKeyFromPEM(publicKeyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't parse jwt public key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
352
backend/internal/service/oidc_service.go
Normal file
352
backend/internal/service/oidc_service.go
Normal file
@@ -0,0 +1,352 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OidcService struct {
|
||||
db *gorm.DB
|
||||
jwtService *JwtService
|
||||
}
|
||||
|
||||
func NewOidcService(db *gorm.DB, jwtService *JwtService) *OidcService {
|
||||
return &OidcService{
|
||||
db: db,
|
||||
jwtService: jwtService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) {
|
||||
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
|
||||
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)
|
||||
|
||||
if userAuthorizedOIDCClient.Scope != input.Scope {
|
||||
return "", "", common.ErrOidcMissingAuthorization
|
||||
}
|
||||
|
||||
callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
|
||||
return code, callbackURL, err
|
||||
}
|
||||
|
||||
func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto, userID string) (string, string, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", input.ClientID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
callbackURL, err := getCallbackURL(client, input.CallbackURL)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
userAuthorizedClient := model.UserAuthorizedOidcClient{
|
||||
UserID: userID,
|
||||
ClientID: input.ClientID,
|
||||
Scope: input.Scope,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&userAuthorizedClient).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
err = s.db.Model(&userAuthorizedClient).Update("scope", input.Scope).Error
|
||||
} else {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
|
||||
return code, callbackURL, err
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
|
||||
if grantType != "authorization_code" {
|
||||
return "", "", common.ErrOidcGrantTypeNotSupported
|
||||
}
|
||||
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return "", "", common.ErrOidcMissingClientCredentials
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
|
||||
if err != nil {
|
||||
return "", "", common.ErrOidcClientSecretInvalid
|
||||
}
|
||||
|
||||
var authorizationCodeMetaData model.OidcAuthorizationCode
|
||||
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
|
||||
if err != nil {
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
}
|
||||
|
||||
if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.Before(time.Now()) {
|
||||
return "", "", common.ErrOidcInvalidAuthorizationCode
|
||||
}
|
||||
|
||||
userClaims, err := s.GetUserClaimsForClient(authorizationCodeMetaData.UserID, clientID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
idToken, err := s.jwtService.GenerateIDToken(userClaims, clientID, authorizationCodeMetaData.Nonce)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
accessToken, err := s.jwtService.GenerateOauthAccessToken(authorizationCodeMetaData.User, clientID)
|
||||
|
||||
s.db.Delete(&authorizationCodeMetaData)
|
||||
|
||||
return idToken, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClient(clientID string) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) ListClients(searchTerm string, page int, pageSize int) ([]model.OidcClient, utils.PaginationResponse, error) {
|
||||
var clients []model.OidcClient
|
||||
|
||||
query := s.db.Preload("CreatedBy").Model(&model.OidcClient{})
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("name LIKE ?", searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.Paginate(page, pageSize, query, &clients)
|
||||
if err != nil {
|
||||
return nil, utils.PaginationResponse{}, err
|
||||
}
|
||||
|
||||
return clients, pagination, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClient(input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||
client := model.OidcClient{
|
||||
Name: input.Name,
|
||||
CallbackURLs: input.CallbackURLs,
|
||||
CreatedByID: userID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&client).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDto) (model.OidcClient, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.Preload("CreatedBy").First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
client.Name = input.Name
|
||||
client.CallbackURLs = input.CallbackURLs
|
||||
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
return model.OidcClient{}, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClient(clientID string) error {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&client).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) CreateClientSecret(clientID string) (string, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientSecret, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
client.Secret = string(hashedSecret)
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetClientLogo(clientID string) (string, string, error) {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
return "", "", errors.New("image not found")
|
||||
}
|
||||
|
||||
imageType := *client.ImageType
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, imageType)
|
||||
mimeType := utils.GetImageMimeType(imageType)
|
||||
|
||||
return imagePath, mimeType, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) UpdateClientLogo(clientID string, file *multipart.FileHeader) error {
|
||||
fileType := utils.GetFileExtension(file.Filename)
|
||||
if mimeType := utils.GetImageMimeType(fileType); mimeType == "" {
|
||||
return common.ErrFileTypeNotSupported
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, clientID, fileType)
|
||||
if err := utils.SaveFile(file, imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if client.ImageType != nil && fileType != *client.ImageType {
|
||||
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
if err := os.Remove(oldImagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
client.ImageType = &fileType
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) DeleteClientLogo(clientID string) error {
|
||||
var client model.OidcClient
|
||||
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if client.ImageType == nil {
|
||||
return errors.New("image not found")
|
||||
}
|
||||
|
||||
imagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
||||
if err := os.Remove(imagePath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.ImageType = nil
|
||||
if err := s.db.Save(&client).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (map[string]interface{}, error) {
|
||||
var authorizedOidcClient model.UserAuthorizedOidcClient
|
||||
if err := s.db.Preload("User").First(&authorizedOidcClient, "user_id = ? AND client_id = ?", userID, clientID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := authorizedOidcClient.User
|
||||
scope := authorizedOidcClient.Scope
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": user.ID,
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
profileClaims := map[string]interface{}{
|
||||
"given_name": user.FirstName,
|
||||
"family_name": user.LastName,
|
||||
"preferred_username": user.Username,
|
||||
}
|
||||
|
||||
if strings.Contains(scope, "profile") {
|
||||
for k, v := range profileClaims {
|
||||
claims[k] = v
|
||||
}
|
||||
}
|
||||
if strings.Contains(scope, "email") {
|
||||
claims["email"] = user.Email
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
oidcAuthorizationCode := model.OidcAuthorizationCode{
|
||||
ExpiresAt: time.Now().Add(15 * time.Minute),
|
||||
Code: randomString,
|
||||
ClientID: clientID,
|
||||
UserID: userID,
|
||||
Scope: scope,
|
||||
Nonce: nonce,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return randomString, nil
|
||||
}
|
||||
|
||||
func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
|
||||
if inputCallbackURL == "" {
|
||||
return client.CallbackURLs[0], nil
|
||||
}
|
||||
if slices.Contains(client.CallbackURLs, inputCallbackURL) {
|
||||
return inputCallbackURL, nil
|
||||
}
|
||||
|
||||
return "", common.ErrOidcInvalidCallbackURL
|
||||
}
|
||||
@@ -1,48 +1,33 @@
|
||||
package handler
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/fxamacker/cbor/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"golang-rest-api-template/internal/common"
|
||||
"golang-rest-api-template/internal/model"
|
||||
"golang-rest-api-template/internal/utils"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func RegisterTestRoutes(group *gin.RouterGroup) {
|
||||
group.POST("/test/reset", resetAndSeedHandler)
|
||||
type TestService struct {
|
||||
db *gorm.DB
|
||||
appConfigService *AppConfigService
|
||||
}
|
||||
|
||||
func resetAndSeedHandler(c *gin.Context) {
|
||||
if err := resetDatabase(); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := resetApplicationImages(); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := seedDatabase(); err != nil {
|
||||
utils.UnknownHandlerError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"message": "Database reset and seeded"})
|
||||
func NewTestService(db *gorm.DB, appConfigService *AppConfigService) *TestService {
|
||||
return &TestService{db: db, appConfigService: appConfigService}
|
||||
}
|
||||
|
||||
// seedDatabase seeds the database with initial data and uses a transaction to ensure atomicity.
|
||||
func seedDatabase() error {
|
||||
return common.DB.Transaction(func(tx *gorm.DB) error {
|
||||
func (s *TestService) SeedDatabase() error {
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
users := []model.User{
|
||||
{
|
||||
Base: model.Base{
|
||||
@@ -76,20 +61,20 @@ func seedDatabase() error {
|
||||
Base: model.Base{
|
||||
ID: "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||
},
|
||||
Name: "Nextcloud",
|
||||
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
||||
CallbackURL: "http://nextcloud/auth/callback",
|
||||
ImageType: utils.StringPointer("png"),
|
||||
CreatedByID: users[0].ID,
|
||||
Name: "Nextcloud",
|
||||
Secret: "$2a$10$9dypwot8nGuCjT6wQWWpJOckZfRprhe2EkwpKizxS/fpVHrOLEJHC", // w2mUeZISmEvIDMEDvpY0PnxQIpj1m3zY
|
||||
CallbackURLs: model.CallbackURLs{"http://nextcloud/auth/callback"},
|
||||
ImageType: utils.StringPointer("png"),
|
||||
CreatedByID: users[0].ID,
|
||||
},
|
||||
{
|
||||
Base: model.Base{
|
||||
ID: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
|
||||
},
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURL: "http://immich/auth/callback",
|
||||
CreatedByID: users[0].ID,
|
||||
Name: "Immich",
|
||||
Secret: "$2a$10$Ak.FP8riD1ssy2AGGbG.gOpnp/rBpymd74j0nxNMtW0GG1Lb4gzxe", // PYjrE9u4v9GVqXKi52eur0eb2Ci4kc0x
|
||||
CallbackURLs: model.CallbackURLs{"http://immich/auth/callback"},
|
||||
CreatedByID: users[0].ID,
|
||||
},
|
||||
}
|
||||
for _, client := range oidcClients {
|
||||
@@ -128,11 +113,16 @@ func seedDatabase() error {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey1, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg==")
|
||||
publicKey2, err := getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA==")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
webauthnCredentials := []model.WebauthnCredential{
|
||||
{
|
||||
Name: "Passkey 1",
|
||||
CredentialID: "test-credential-1",
|
||||
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEwcOo5KV169KR67QEHrcYkeXE3CCxv2BgwnSq4VYTQxyLtdmKxegexa8JdwFKhKXa2BMI9xaN15BoL6wSCRFJhg=="),
|
||||
PublicKey: publicKey1,
|
||||
AttestationType: "none",
|
||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||
UserID: users[0].ID,
|
||||
@@ -140,7 +130,7 @@ func seedDatabase() error {
|
||||
{
|
||||
Name: "Passkey 2",
|
||||
CredentialID: "test-credential-2",
|
||||
PublicKey: getCborPublicKey("MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAESq/wR8QbBu3dKnpaw/v0mDxFFDwnJ/L5XHSg2tAmq5x1BpSMmIr3+DxCbybVvGRmWGh8kKhy7SMnK91M6rFHTA=="),
|
||||
PublicKey: publicKey2,
|
||||
AttestationType: "none",
|
||||
Transport: model.AuthenticatorTransportList{protocol.Internal},
|
||||
UserID: users[0].ID,
|
||||
@@ -165,9 +155,8 @@ func seedDatabase() error {
|
||||
})
|
||||
}
|
||||
|
||||
// resetDatabase resets the database by deleting all rows from each table.
|
||||
func resetDatabase() error {
|
||||
err := common.DB.Transaction(func(tx *gorm.DB) error {
|
||||
func (s *TestService) ResetDatabase() error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
var tables []string
|
||||
if err := tx.Raw("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' AND name != 'schema_migrations';").Scan(&tables).Error; err != nil {
|
||||
return err
|
||||
@@ -183,13 +172,11 @@ func resetDatabase() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.InitDbConfig()
|
||||
return nil
|
||||
err = s.appConfigService.InitDbConfig()
|
||||
return err
|
||||
}
|
||||
|
||||
// resetApplicationImages resets the application images by removing existing images and replacing them with the default ones
|
||||
func resetApplicationImages() error {
|
||||
|
||||
func (s *TestService) ResetApplicationImages() error {
|
||||
if err := os.RemoveAll(common.EnvConfig.UploadPath); err != nil {
|
||||
log.Printf("Error removing directory: %v", err)
|
||||
return err
|
||||
@@ -204,20 +191,19 @@ func resetApplicationImages() error {
|
||||
}
|
||||
|
||||
// getCborPublicKey decodes a Base64 encoded public key and returns the CBOR encoded COSE key
|
||||
func getCborPublicKey(base64PublicKey string) []byte {
|
||||
func getCborPublicKey(base64PublicKey string) ([]byte, error) {
|
||||
decodedKey, err := base64.StdEncoding.DecodeString(base64PublicKey)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to decode base64 key: %v", err)
|
||||
return nil, fmt.Errorf("failed to decode base64 key: %w", err)
|
||||
}
|
||||
|
||||
pubKey, err := x509.ParsePKIXPublicKey(decodedKey)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse public key: %v", err)
|
||||
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
ecdsaPubKey, ok := pubKey.(*ecdsa.PublicKey)
|
||||
if !ok {
|
||||
log.Fatalf("Not an ECDSA public key")
|
||||
return nil, fmt.Errorf("not an ECDSA public key")
|
||||
}
|
||||
|
||||
coseKey := map[int]interface{}{
|
||||
@@ -230,8 +216,8 @@ func getCborPublicKey(base64PublicKey string) []byte {
|
||||
|
||||
cborPublicKey, err := cbor.Marshal(coseKey)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to encode CBOR: %v", err)
|
||||
return nil, fmt.Errorf("failed to marshal COSE key: %w", err)
|
||||
}
|
||||
|
||||
return cborPublicKey
|
||||
return cborPublicKey, nil
|
||||
}
|
||||
173
backend/internal/service/user_sevice.go
Normal file
173
backend/internal/service/user_sevice.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/dto"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
db *gorm.DB
|
||||
jwtService *JwtService
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB, jwtService *JwtService) *UserService {
|
||||
return &UserService{db: db, jwtService: jwtService}
|
||||
}
|
||||
|
||||
func (s *UserService) ListUsers(searchTerm string, page int, pageSize int) ([]model.User, utils.PaginationResponse, error) {
|
||||
var users []model.User
|
||||
query := s.db.Model(&model.User{})
|
||||
|
||||
if searchTerm != "" {
|
||||
searchPattern := "%" + searchTerm + "%"
|
||||
query = query.Where("email LIKE ? OR first_name LIKE ? OR username LIKE ?", searchPattern, searchPattern, searchPattern)
|
||||
}
|
||||
|
||||
pagination, err := utils.Paginate(page, pageSize, query, &users)
|
||||
return users, pagination, err
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(userID string) (model.User, error) {
|
||||
var user model.User
|
||||
err := s.db.Where("id = ?", userID).First(&user).Error
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (s *UserService) DeleteUser(userID string) error {
|
||||
var user model.User
|
||||
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Delete(&user).Error
|
||||
}
|
||||
|
||||
func (s *UserService) CreateUser(input dto.UserCreateDto) (model.User, error) {
|
||||
user := &model.User{
|
||||
FirstName: input.FirstName,
|
||||
LastName: input.LastName,
|
||||
Email: input.Email,
|
||||
Username: input.Username,
|
||||
IsAdmin: input.IsAdmin,
|
||||
}
|
||||
if err := s.db.Create(user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return model.User{}, s.checkDuplicatedFields(*user)
|
||||
}
|
||||
return model.User{}, err
|
||||
}
|
||||
return *user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUser(userID string, updatedUser dto.UserCreateDto, updateOwnUser bool) (model.User, error) {
|
||||
var user model.User
|
||||
if err := s.db.Where("id = ?", userID).First(&user).Error; err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
user.FirstName = updatedUser.FirstName
|
||||
user.LastName = updatedUser.LastName
|
||||
user.Email = updatedUser.Email
|
||||
user.Username = updatedUser.Username
|
||||
if !updateOwnUser {
|
||||
user.IsAdmin = updatedUser.IsAdmin
|
||||
}
|
||||
|
||||
if err := s.db.Save(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return user, s.checkDuplicatedFields(user)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) CreateOneTimeAccessToken(userID string, expiresAt time.Time) (string, error) {
|
||||
randomString, err := utils.GenerateRandomAlphanumericString(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
oneTimeAccessToken := model.OneTimeAccessToken{
|
||||
UserID: userID,
|
||||
ExpiresAt: expiresAt,
|
||||
Token: randomString,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&oneTimeAccessToken).Error; err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return oneTimeAccessToken.Token, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ExchangeOneTimeAccessToken(token string) (model.User, string, error) {
|
||||
var oneTimeAccessToken model.OneTimeAccessToken
|
||||
if err := s.db.Where("token = ? AND expires_at > ?", token, utils.FormatDateForDb(time.Now())).Preload("User").First(&oneTimeAccessToken).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return model.User{}, "", common.ErrTokenInvalidOrExpired
|
||||
}
|
||||
return model.User{}, "", err
|
||||
}
|
||||
accessToken, err := s.jwtService.GenerateAccessToken(oneTimeAccessToken.User)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&oneTimeAccessToken).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return oneTimeAccessToken.User, accessToken, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SetupInitialAdmin() (model.User, string, error) {
|
||||
var userCount int64
|
||||
if err := s.db.Model(&model.User{}).Count(&userCount).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
if userCount > 1 {
|
||||
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||
}
|
||||
|
||||
user := model.User{
|
||||
FirstName: "Admin",
|
||||
LastName: "Admin",
|
||||
Username: "admin",
|
||||
Email: "admin@admin.com",
|
||||
IsAdmin: true,
|
||||
}
|
||||
|
||||
if err := s.db.Model(&model.User{}).Preload("Credentials").FirstOrCreate(&user).Error; err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
if len(user.Credentials) > 0 {
|
||||
return model.User{}, "", common.ErrSetupAlreadyCompleted
|
||||
}
|
||||
|
||||
token, err := s.jwtService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
return model.User{}, "", err
|
||||
}
|
||||
|
||||
return user, token, nil
|
||||
}
|
||||
|
||||
func (s *UserService) checkDuplicatedFields(user model.User) error {
|
||||
var existingUser model.User
|
||||
if s.db.Where("id != ? AND email = ?", user.ID, user.Email).First(&existingUser).Error == nil {
|
||||
return common.ErrEmailTaken
|
||||
}
|
||||
|
||||
if s.db.Where("id != ? AND username = ?", user.ID, user.Username).First(&existingUser).Error == nil {
|
||||
return common.ErrUsernameTaken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
196
backend/internal/service/webauthn_service.go
Normal file
196
backend/internal/service/webauthn_service.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
"github.com/stonith404/pocket-id/backend/internal/common"
|
||||
"github.com/stonith404/pocket-id/backend/internal/model"
|
||||
"github.com/stonith404/pocket-id/backend/internal/utils"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WebAuthnService struct {
|
||||
db *gorm.DB
|
||||
webAuthn *webauthn.WebAuthn
|
||||
}
|
||||
|
||||
func NewWebAuthnService(db *gorm.DB, appConfigService *AppConfigService) *WebAuthnService {
|
||||
webauthnConfig := &webauthn.Config{
|
||||
RPDisplayName: appConfigService.DbConfig.AppName.Value,
|
||||
RPID: utils.GetHostFromURL(common.EnvConfig.AppURL),
|
||||
RPOrigins: []string{common.EnvConfig.AppURL},
|
||||
Timeouts: webauthn.TimeoutsConfig{
|
||||
Login: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
Registration: webauthn.TimeoutConfig{
|
||||
Enforce: true,
|
||||
Timeout: time.Second * 60,
|
||||
TimeoutUVD: time.Second * 60,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wa, _ := webauthn.New(webauthnConfig)
|
||||
return &WebAuthnService{db: db, webAuthn: wa}
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginRegistration(userID string) (*model.PublicKeyCredentialCreationOptions, error) {
|
||||
var user model.User
|
||||
if err := s.db.Preload("Credentials").Find(&user, "id = ?", userID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
options, session, err := s.webAuthn.BeginRegistration(&user, webauthn.WithResidentKeyRequirement(protocol.ResidentKeyRequirementRequired), webauthn.WithExclusions(user.WebAuthnCredentialDescriptors()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionToStore := &model.WebauthnSession{
|
||||
ExpiresAt: session.Expires,
|
||||
Challenge: session.Challenge,
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.PublicKeyCredentialCreationOptions{
|
||||
Response: options.Response,
|
||||
SessionID: sessionToStore.ID,
|
||||
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyRegistration(sessionID, userID string, r *http.Request) (model.WebauthnCredential, error) {
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
session := webauthn.SessionData{
|
||||
Challenge: storedSession.Challenge,
|
||||
Expires: storedSession.ExpiresAt,
|
||||
UserID: []byte(userID),
|
||||
}
|
||||
|
||||
var user model.User
|
||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
credential, err := s.webAuthn.FinishRegistration(&user, session, r)
|
||||
if err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
credentialToStore := model.WebauthnCredential{
|
||||
Name: "New Passkey",
|
||||
CredentialID: string(credential.ID),
|
||||
AttestationType: credential.AttestationType,
|
||||
PublicKey: credential.PublicKey,
|
||||
Transport: credential.Transport,
|
||||
UserID: user.ID,
|
||||
BackupEligible: credential.Flags.BackupEligible,
|
||||
BackupState: credential.Flags.BackupState,
|
||||
}
|
||||
if err := s.db.Create(&credentialToStore).Error; err != nil {
|
||||
return model.WebauthnCredential{}, err
|
||||
}
|
||||
|
||||
return credentialToStore, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) BeginLogin() (*model.PublicKeyCredentialRequestOptions, error) {
|
||||
options, session, err := s.webAuthn.BeginDiscoverableLogin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionToStore := &model.WebauthnSession{
|
||||
ExpiresAt: session.Expires,
|
||||
Challenge: session.Challenge,
|
||||
UserVerification: string(session.UserVerification),
|
||||
}
|
||||
|
||||
if err := s.db.Create(&sessionToStore).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.PublicKeyCredentialRequestOptions{
|
||||
Response: options.Response,
|
||||
SessionID: sessionToStore.ID,
|
||||
Timeout: s.webAuthn.Config.Timeouts.Registration.Timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) VerifyLogin(sessionID, userID string, credentialAssertionData *protocol.ParsedCredentialAssertionData) (model.User, error) {
|
||||
var storedSession model.WebauthnSession
|
||||
if err := s.db.First(&storedSession, "id = ?", sessionID).Error; err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
session := webauthn.SessionData{
|
||||
Challenge: storedSession.Challenge,
|
||||
Expires: storedSession.ExpiresAt,
|
||||
}
|
||||
|
||||
var user *model.User
|
||||
_, err := s.webAuthn.ValidateDiscoverableLogin(func(_, userHandle []byte) (webauthn.User, error) {
|
||||
if err := s.db.Preload("Credentials").First(&user, "id = ?", string(userHandle)).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user, nil
|
||||
}, session, credentialAssertionData)
|
||||
|
||||
if err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
if err := s.db.Find(&user, "id = ?", userID).Error; err != nil {
|
||||
return model.User{}, err
|
||||
}
|
||||
|
||||
return *user, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) ListCredentials(userID string) ([]model.WebauthnCredential, error) {
|
||||
var credentials []model.WebauthnCredential
|
||||
if err := s.db.Find(&credentials, "user_id = ?", userID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return credentials, nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) DeleteCredential(userID, credentialID string) error {
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.First(&credential, "id = ? AND user_id = ?", credentialID, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.db.Delete(&credential).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WebAuthnService) UpdateCredential(userID, credentialID, name string) (model.WebauthnCredential, error) {
|
||||
var credential model.WebauthnCredential
|
||||
if err := s.db.Where("id = ? AND user_id = ?", credentialID, userID).First(&credential).Error; err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
credential.Name = name
|
||||
|
||||
if err := s.db.Save(&credential).Error; err != nil {
|
||||
return credential, err
|
||||
}
|
||||
|
||||
return credential, nil
|
||||
}
|
||||
75
backend/internal/utils/controller_error_util.go
Normal file
75
backend/internal/utils/controller_error_util.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func ControllerError(c *gin.Context, err error) {
|
||||
// Check for record not found errors
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
CustomControllerError(c, http.StatusNotFound, "Record not found")
|
||||
return
|
||||
}
|
||||
|
||||
// Check for validation errors
|
||||
var validationErrors validator.ValidationErrors
|
||||
if errors.As(err, &validationErrors) {
|
||||
message := handleValidationError(validationErrors)
|
||||
CustomControllerError(c, http.StatusBadRequest, message)
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
log.Println(err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||
}
|
||||
|
||||
func handleValidationError(validationErrors validator.ValidationErrors) string {
|
||||
var errorMessages []string
|
||||
|
||||
for _, ve := range validationErrors {
|
||||
fieldName := ve.Field()
|
||||
var errorMessage string
|
||||
switch ve.Tag() {
|
||||
case "required":
|
||||
errorMessage = fmt.Sprintf("%s is required", fieldName)
|
||||
case "email":
|
||||
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
|
||||
case "username":
|
||||
errorMessage = fmt.Sprintf("%s must contain only lowercase letters, numbers, and underscores", fieldName)
|
||||
case "url":
|
||||
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
|
||||
case "min":
|
||||
errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
|
||||
case "max":
|
||||
errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
|
||||
case "urlList":
|
||||
errorMessage = fmt.Sprintf("%s must be a list of valid URLs", fieldName)
|
||||
default:
|
||||
errorMessage = fmt.Sprintf("%s is invalid", fieldName)
|
||||
}
|
||||
|
||||
errorMessages = append(errorMessages, errorMessage)
|
||||
}
|
||||
|
||||
// Join all the error messages into a single string
|
||||
combinedErrors := strings.Join(errorMessages, ", ")
|
||||
|
||||
return combinedErrors
|
||||
}
|
||||
|
||||
func CustomControllerError(c *gin.Context, statusCode int, message string) {
|
||||
// Capitalize the first letter of the message
|
||||
message = strings.ToUpper(message[:1]) + message[1:]
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -71,3 +72,24 @@ func copyFile(srcFilePath, destFilePath string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SaveFile(file *multipart.FileHeader, dst string) error {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
if err = os.MkdirAll(filepath.Dir(dst), 0o750); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, src)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UnknownHandlerError(c *gin.Context, err error) {
|
||||
log.Println(err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Something went wrong"})
|
||||
}
|
||||
|
||||
func HandlerError(c *gin.Context, statusCode int, message string) {
|
||||
// Capitalize the first letter of the message
|
||||
message = strings.ToUpper(message[:1]) + message[1:]
|
||||
c.JSON(statusCode, gin.H{"error": message})
|
||||
}
|
||||
@@ -1,9 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type PaginationResponse struct {
|
||||
@@ -12,10 +10,7 @@ type PaginationResponse struct {
|
||||
CurrentPage int `json:"currentPage"`
|
||||
}
|
||||
|
||||
func Paginate(c *gin.Context, db *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("limit", "10"))
|
||||
|
||||
func Paginate(page int, pageSize int, db *gorm.DB, result interface{}) (PaginationResponse, error) {
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ CREATE TABLE webauthn_credentials
|
||||
credential_id TEXT NOT NULL UNIQUE,
|
||||
public_key BLOB NOT NULL,
|
||||
attestation_type TEXT NOT NULL,
|
||||
transport TEXT NOT NULL,
|
||||
transport BLOB NOT NULL,
|
||||
user_id TEXT REFERENCES users
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE app_config_variables
|
||||
RENAME TO application_configuration_variables;
|
||||
@@ -0,0 +1,2 @@
|
||||
ALTER TABLE application_configuration_variables
|
||||
RENAME TO app_config_variables;
|
||||
@@ -0,0 +1,23 @@
|
||||
create table oidc_clients
|
||||
(
|
||||
id TEXT not null primary key,
|
||||
created_at DATETIME,
|
||||
name TEXT,
|
||||
secret TEXT,
|
||||
callback_url TEXT,
|
||||
image_type TEXT,
|
||||
created_by_id TEXT
|
||||
references users
|
||||
);
|
||||
|
||||
insert into oidc_clients(id, created_at, name, secret, callback_url, image_type, created_by_id)
|
||||
select id,
|
||||
created_at,
|
||||
name,
|
||||
secret,
|
||||
json_extract(callback_urls, '$[0]'),
|
||||
image_type,
|
||||
created_by_id
|
||||
from oidc_clients_dg_tmp;
|
||||
|
||||
drop table oidc_clients_dg_tmp;
|
||||
@@ -0,0 +1,26 @@
|
||||
create table oidc_clients_dg_tmp
|
||||
(
|
||||
id TEXT not null primary key,
|
||||
created_at DATETIME,
|
||||
name TEXT,
|
||||
secret TEXT,
|
||||
callback_urls BLOB,
|
||||
image_type TEXT,
|
||||
created_by_id TEXT
|
||||
references users
|
||||
);
|
||||
|
||||
insert into oidc_clients_dg_tmp(id, created_at, name, secret, callback_urls, image_type, created_by_id)
|
||||
select id,
|
||||
created_at,
|
||||
name,
|
||||
secret,
|
||||
CAST('["' || callback_url || '"]' AS BLOB),
|
||||
image_type,
|
||||
created_by_id
|
||||
from oidc_clients;
|
||||
|
||||
drop table oidc_clients;
|
||||
|
||||
alter table oidc_clients_dg_tmp
|
||||
rename to oidc_clients;
|
||||
81
docs/proxy-services.md
Normal file
81
docs/proxy-services.md
Normal file
@@ -0,0 +1,81 @@
|
||||
# Proxy Services through Pocket ID
|
||||
|
||||
The goal of Pocket ID is to stay simple. Because of that we don't have a built-in proxy provider. However, you can use [OAuth2 Proxy](https://oauth2-proxy.github.io/) to add authentication to your services that don't support OIDC. This guide will show you how to set up OAuth2 Proxy with Pocket ID.
|
||||
|
||||
## Docker Setup
|
||||
|
||||
#### 1. Add OAuth2 proxy to the service that should be proxied.
|
||||
|
||||
To configure OAuth2 Proxy with Pocket ID, you have to add the following service to the service that should be proxied. E.g., [Uptime Kuma](https://github.com/louislam/uptime-kuma) should be proxied, you can add the following service to the `docker-compose.yml` of Uptime Kuma:
|
||||
|
||||
```yaml
|
||||
# Example with Uptime Kuma
|
||||
# uptime-kuma:
|
||||
# image: louislam/uptime-kuma
|
||||
oauth2-proxy:
|
||||
image: quay.io/oauth2-proxy/oauth2-proxy:v7.6.0
|
||||
command: --config /oauth2-proxy.cfg
|
||||
volumes:
|
||||
- "./oauth2-proxy.cfg:/oauth2-proxy.cfg"
|
||||
ports:
|
||||
- 4180:4180
|
||||
```
|
||||
|
||||
#### 2. Create a new OIDC client in Pocket ID.
|
||||
|
||||
Create a new OIDC client in Pocket ID by navigating to `https://<your-domain>/settings/admin/oidc-clients`. After adding the client, you will obtain the client ID and client secret.
|
||||
|
||||
#### 3. Create a configuration file for OAuth2 Proxy.
|
||||
|
||||
Create a configuration file named `oauth2-proxy.cfg` in the same directory as your `docker-compose.yml` file of the service that should be proxied (e.g. Uptime Kuma). This file will contain the necessary configurations for OAuth2 Proxy to work with Pocket ID.
|
||||
|
||||
Here is the recommend `oauth2-proxy.cfg` configuration:
|
||||
|
||||
```cfg
|
||||
# Replace with your own credentials
|
||||
client_id="client-id-from-pocket-id"
|
||||
client_secret="client-secret-from-pocket-id"
|
||||
oidc_issuer_url="https://<your-pocket-id-domain>"
|
||||
|
||||
# Replace with a secure random string
|
||||
cookie_secret="random-string"
|
||||
|
||||
# Upstream servers (e.g http://uptime-kuma:3001)
|
||||
upstreams="http://<service-to-be-proxied>:<port>"
|
||||
|
||||
# Additional Configuration
|
||||
provider="oidc"
|
||||
scope = "openid email profile"
|
||||
|
||||
# If you are using a reverse proxy in front of OAuth2 Proxy
|
||||
reverse_proxy = true
|
||||
|
||||
# Email domains allowed for authentication
|
||||
email_domains = ["*"]
|
||||
|
||||
# If you are using HTTPS
|
||||
cookie_secure="true"
|
||||
|
||||
# Listen on all interfaces
|
||||
http_address="0.0.0.0:4180"
|
||||
```
|
||||
|
||||
For additional configuration options, refer to the official [OAuth2 Proxy documentation](https://oauth2-proxy.github.io/oauth2-proxy/configuration/overview).
|
||||
|
||||
#### 4. Start the services.
|
||||
|
||||
After creating the configuration file, you can start the services using Docker Compose:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
#### 5. Access the service.
|
||||
|
||||
You can now access the service through OAuth2 Proxy by visiting `http://localhost:4180`.
|
||||
|
||||
## Standalone Installation
|
||||
|
||||
Setting up OAuth2 Proxy with Pocket ID without Docker is similar to the Docker setup. As the setup depends on your environment, you have to adjust the steps accordingly but is should be similar to the Docker setup.
|
||||
|
||||
You can visit the official [OAuth2 Proxy documentation](https://oauth2-proxy.github.io/oauth2-proxy/installation) for more information.
|
||||
@@ -1 +1,2 @@
|
||||
PUBLIC_APP_URL=http://localhost
|
||||
PUBLIC_APP_URL=http://localhost
|
||||
INTERNAL_BACKEND_URL=http://localhost:8080
|
||||
|
||||
581
frontend/package-lock.json
generated
581
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -12,46 +12,46 @@
|
||||
"format": "prettier --write ."
|
||||
},
|
||||
"devDependencies": {
|
||||
"@playwright/test": "^1.46.0",
|
||||
"@sveltejs/adapter-auto": "^3.0.0",
|
||||
"@sveltejs/adapter-node": "^5.2.0",
|
||||
"@sveltejs/kit": "^2.0.0",
|
||||
"@sveltejs/vite-plugin-svelte": "^3.0.0",
|
||||
"@types/eslint": "^8.56.7",
|
||||
"@playwright/test": "^1.46.1",
|
||||
"@sveltejs/adapter-auto": "^3.2.4",
|
||||
"@sveltejs/adapter-node": "^5.2.2",
|
||||
"@sveltejs/kit": "^2.5.24",
|
||||
"@sveltejs/vite-plugin-svelte": "^3.1.2",
|
||||
"@types/eslint": "^9.6.0",
|
||||
"@types/jsonwebtoken": "^9.0.6",
|
||||
"@types/node": "^22.1.0",
|
||||
"autoprefixer": "^10.4.19",
|
||||
"@types/node": "^22.5.0",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"cbor-js": "^0.1.0",
|
||||
"eslint": "^9.0.0",
|
||||
"eslint": "^9.9.1",
|
||||
"eslint-config-prettier": "^9.1.0",
|
||||
"eslint-plugin-svelte": "^2.36.0",
|
||||
"globals": "^15.0.0",
|
||||
"postcss": "^8.4.38",
|
||||
"prettier": "^3.1.1",
|
||||
"prettier-plugin-svelte": "^3.1.2",
|
||||
"prettier-plugin-tailwindcss": "^0.6.4",
|
||||
"eslint-plugin-svelte": "^2.40.0",
|
||||
"globals": "^15.9.0",
|
||||
"postcss": "^8.4.41",
|
||||
"prettier": "^3.3.3",
|
||||
"prettier-plugin-svelte": "^3.2.6",
|
||||
"prettier-plugin-tailwindcss": "^0.6.6",
|
||||
"svelte": "^5.0.0-next.1",
|
||||
"svelte-check": "^3.6.0",
|
||||
"tailwindcss": "^3.4.4",
|
||||
"tslib": "^2.4.1",
|
||||
"typescript": "^5.0.0",
|
||||
"typescript-eslint": "^8.0.0-alpha.20",
|
||||
"vite": "^5.0.3"
|
||||
"svelte-check": "^3.8.6",
|
||||
"tailwindcss": "^3.4.10",
|
||||
"tslib": "^2.7.0",
|
||||
"typescript": "^5.5.4",
|
||||
"typescript-eslint": "^8.2.0",
|
||||
"vite": "^5.4.2"
|
||||
},
|
||||
"type": "module",
|
||||
"dependencies": {
|
||||
"@simplewebauthn/browser": "^10.0.0",
|
||||
"axios": "^1.7.2",
|
||||
"bits-ui": "^0.21.12",
|
||||
"axios": "^1.7.5",
|
||||
"bits-ui": "^0.21.13",
|
||||
"clsx": "^2.1.1",
|
||||
"crypto": "^1.0.1",
|
||||
"formsnap": "^1.0.1",
|
||||
"jsonwebtoken": "^9.0.2",
|
||||
"lucide-svelte": "^0.399.0",
|
||||
"lucide-svelte": "^0.435.0",
|
||||
"mode-watcher": "^0.4.1",
|
||||
"svelte-sonner": "^0.3.27",
|
||||
"sveltekit-superforms": "^2.16.1",
|
||||
"tailwind-merge": "^2.3.0",
|
||||
"sveltekit-superforms": "^2.17.0",
|
||||
"tailwind-merge": "^2.5.2",
|
||||
"tailwind-variants": "^0.2.1",
|
||||
"zod": "^3.23.8"
|
||||
}
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import { env } from '$env/dynamic/private';
|
||||
import type { Handle, HandleServerError } from '@sveltejs/kit';
|
||||
import { AxiosError } from 'axios';
|
||||
import jwt from 'jsonwebtoken';
|
||||
|
||||
// Workaround so that we can also import this environment variable into client-side code
|
||||
// If we would directly import $env/dynamic/private into the api-service.ts file, it would throw an error
|
||||
// this is still secure as process will just be undefined in the browser
|
||||
process.env.INTERNAL_BACKEND_URL = env.INTERNAL_BACKEND_URL ?? 'http://localhost:8080';
|
||||
|
||||
export const handle: Handle = async ({ event, resolve }) => {
|
||||
const accessToken = event.cookies.get('access_token');
|
||||
|
||||
|
||||
@@ -2,15 +2,17 @@
|
||||
import { Label } from '$lib/components/ui/label';
|
||||
import type { FormInput } from '$lib/utils/form-util';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
import { Input } from './ui/input';
|
||||
|
||||
let {
|
||||
input = $bindable(),
|
||||
label,
|
||||
description,
|
||||
children
|
||||
}: {
|
||||
input: FormInput<string | boolean | number>;
|
||||
children,
|
||||
...restProps
|
||||
}: HTMLAttributes<HTMLDivElement> & {
|
||||
input?: FormInput<string | boolean | number>;
|
||||
label: string;
|
||||
description?: string;
|
||||
children?: Snippet;
|
||||
@@ -19,19 +21,19 @@
|
||||
const id = label.toLowerCase().replace(/ /g, '-');
|
||||
</script>
|
||||
|
||||
<div>
|
||||
<div {...restProps}>
|
||||
<Label class="mb-0" for={id}>{label}</Label>
|
||||
{#if description}
|
||||
<p class="text-muted-foreground text-xs mt-1">{description}</p>
|
||||
<p class="text-muted-foreground mt-1 text-xs">{description}</p>
|
||||
{/if}
|
||||
<div class="mt-2">
|
||||
{#if children}
|
||||
{@render children()}
|
||||
{:else}
|
||||
{:else if input}
|
||||
<Input {id} bind:value={input.value} />
|
||||
{/if}
|
||||
{#if input.error}
|
||||
<p class="text-sm text-red-500">{input.error}</p>
|
||||
{#if input?.error}
|
||||
<p class="mt-1 text-sm text-red-500">{input.error}</p>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { browser } from '$app/environment';
|
||||
import { env } from '$env/dynamic/public';
|
||||
import axios from 'axios';
|
||||
|
||||
abstract class APIService {
|
||||
baseURL: string = '/api';
|
||||
api = axios.create({
|
||||
withCredentials: true
|
||||
});
|
||||
@@ -11,11 +9,11 @@ abstract class APIService {
|
||||
constructor(accessToken?: string) {
|
||||
if (accessToken) {
|
||||
this.api.defaults.headers.common['Authorization'] = `Bearer ${accessToken}`;
|
||||
} else {
|
||||
this.api.defaults.baseURL = '/api';
|
||||
}
|
||||
if (!browser) {
|
||||
this.api.defaults.baseURL = (env.PUBLIC_APP_URL ?? 'http://localhost') + '/api';
|
||||
if (browser) {
|
||||
this.api.defaults.baseURL = '/api';
|
||||
} else {
|
||||
this.api.defaults.baseURL = process?.env?.INTERNAL_BACKEND_URL + '/api';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,28 @@
|
||||
import type { OidcClient, OidcClientCreate } from '$lib/types/oidc.type';
|
||||
import type { AuthorizeResponse, OidcClient, OidcClientCreate } from '$lib/types/oidc.type';
|
||||
import type { Paginated, PaginationRequest } from '$lib/types/pagination.type';
|
||||
import APIService from './api-service';
|
||||
|
||||
class OidcService extends APIService {
|
||||
async authorize(clientId: string, scope: string, nonce?: string) {
|
||||
async authorize(clientId: string, scope: string, callbackURL : string, nonce?: string) {
|
||||
const res = await this.api.post('/oidc/authorize', {
|
||||
scope,
|
||||
nonce,
|
||||
callbackURL,
|
||||
clientId
|
||||
});
|
||||
|
||||
return res.data.code as string;
|
||||
return res.data as AuthorizeResponse;
|
||||
}
|
||||
|
||||
async authorizeNewClient(clientId: string, scope: string, nonce?: string) {
|
||||
async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string) {
|
||||
const res = await this.api.post('/oidc/authorize/new-client', {
|
||||
scope,
|
||||
nonce,
|
||||
callbackURL,
|
||||
clientId
|
||||
});
|
||||
|
||||
return res.data.code as string;
|
||||
return res.data as AuthorizeResponse;
|
||||
}
|
||||
|
||||
async listClients(search?: string, pagination?: PaginationRequest) {
|
||||
|
||||
@@ -2,7 +2,7 @@ export type OidcClient = {
|
||||
id: string;
|
||||
name: string;
|
||||
logoURL: string;
|
||||
callbackURL: string;
|
||||
callbackURLs: [string, ...string[]];
|
||||
hasLogo: boolean;
|
||||
};
|
||||
|
||||
@@ -11,3 +11,8 @@ export type OidcClientCreate = Omit<OidcClient, 'id' | 'logoURL' | 'hasLogo'>;
|
||||
export type OidcClientCreateWithLogo = OidcClientCreate & {
|
||||
logo: File | null;
|
||||
};
|
||||
|
||||
export type AuthorizeResponse = {
|
||||
code: string;
|
||||
callbackURL: string;
|
||||
};
|
||||
|
||||
@@ -11,6 +11,7 @@ export const load: PageServerLoad = async ({ url, cookies }) => {
|
||||
scope: url.searchParams.get('scope')!,
|
||||
nonce: url.searchParams.get('nonce') || undefined,
|
||||
state: url.searchParams.get('state')!,
|
||||
callbackURL: url.searchParams.get('redirect_uri')!,
|
||||
client
|
||||
};
|
||||
};
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
let authorizationRequired = false;
|
||||
|
||||
export let data: PageData;
|
||||
let { scope, nonce, client, state } = data;
|
||||
let { scope, nonce, client, state, callbackURL } = data;
|
||||
|
||||
async function authorize() {
|
||||
isLoading = true;
|
||||
@@ -36,9 +36,11 @@
|
||||
await webauthnService.finishLogin(authResponse);
|
||||
}
|
||||
|
||||
await oidService.authorize(client!.id, scope, nonce).then(async (code) => {
|
||||
onSuccess(code);
|
||||
});
|
||||
await oidService
|
||||
.authorize(client!.id, scope, callbackURL, nonce)
|
||||
.then(async ({ code, callbackURL }) => {
|
||||
onSuccess(code, callbackURL);
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof AxiosError && e.response?.status === 403) {
|
||||
authorizationRequired = true;
|
||||
@@ -52,19 +54,21 @@
|
||||
async function authorizeNewClient() {
|
||||
isLoading = true;
|
||||
try {
|
||||
await oidService.authorizeNewClient(client!.id, scope, nonce).then(async (code) => {
|
||||
onSuccess(code);
|
||||
});
|
||||
await oidService
|
||||
.authorizeNewClient(client!.id, scope, callbackURL, nonce)
|
||||
.then(async ({ code, callbackURL }) => {
|
||||
onSuccess(code, callbackURL);
|
||||
});
|
||||
} catch (e) {
|
||||
errorMessage = getWebauthnErrorMessage(e);
|
||||
isLoading = false;
|
||||
}
|
||||
}
|
||||
|
||||
function onSuccess(code: string) {
|
||||
function onSuccess(code: string, callbackURL: string) {
|
||||
success = true;
|
||||
setTimeout(() => {
|
||||
window.location.href = `${client!.callbackURL}?code=${code}&state=${state}`;
|
||||
window.location.href = `${callbackURL}?code=${code}&state=${state}`;
|
||||
}, 1000);
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -47,7 +47,6 @@
|
||||
<form onsubmit={onSubmit}>
|
||||
<div class="flex flex-col gap-5">
|
||||
<FormInput label="Application Name" bind:input={$inputs.appName} />
|
||||
|
||||
<FormInput
|
||||
label="Session Duration"
|
||||
description="The duration of a session in minutes before the user has to sign in again."
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
<script lang="ts">
|
||||
import FormInput from '$lib/components/form-input.svelte';
|
||||
import { Button } from '$lib/components/ui/button';
|
||||
import { Input } from '$lib/components/ui/input';
|
||||
import { LucideMinus, LucidePlus } from 'lucide-svelte';
|
||||
import type { Snippet } from 'svelte';
|
||||
import type { HTMLAttributes } from 'svelte/elements';
|
||||
|
||||
let {
|
||||
callbackURLs = $bindable(),
|
||||
error = $bindable(null),
|
||||
...restProps
|
||||
}: HTMLAttributes<HTMLDivElement> & {
|
||||
callbackURLs: string[];
|
||||
error?: string | null;
|
||||
children?: Snippet;
|
||||
} = $props();
|
||||
|
||||
const limit = 5;
|
||||
</script>
|
||||
|
||||
<div {...restProps}>
|
||||
<FormInput label="Callback URLs">
|
||||
<div class="flex flex-col gap-y-2">
|
||||
{#each callbackURLs as _, i}
|
||||
<div class="flex gap-x-2">
|
||||
<Input data-testid={`callback-url-${i + 1}`} bind:value={callbackURLs[i]} />
|
||||
{#if callbackURLs.length > 1}
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
on:click={() => callbackURLs = callbackURLs.filter((_, index) => index !== i)}
|
||||
>
|
||||
<LucideMinus class="h-4 w-4" />
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</FormInput>
|
||||
{#if error}
|
||||
<p class="mt-1 text-sm text-red-500">{error}</p>
|
||||
{/if}
|
||||
{#if callbackURLs.length < limit}
|
||||
<Button
|
||||
class="mt-2"
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
on:click={() => callbackURLs = [...callbackURLs, '']}
|
||||
>
|
||||
<LucidePlus class="mr-1 h-4 w-4" />
|
||||
Add another
|
||||
</Button>
|
||||
{/if}
|
||||
</div>
|
||||
@@ -10,6 +10,7 @@
|
||||
} from '$lib/types/oidc.type';
|
||||
import { createForm } from '$lib/utils/form-util';
|
||||
import { z } from 'zod';
|
||||
import OidcCallbackUrlInput from './oidc-callback-url-input.svelte';
|
||||
|
||||
let {
|
||||
callback,
|
||||
@@ -27,12 +28,12 @@
|
||||
|
||||
const client: OidcClientCreate = {
|
||||
name: existingClient?.name || '',
|
||||
callbackURL: existingClient?.callbackURL || ''
|
||||
callbackURLs: existingClient?.callbackURLs || [""]
|
||||
};
|
||||
|
||||
const formSchema = z.object({
|
||||
name: z.string().min(2).max(50),
|
||||
callbackURL: z.string().url()
|
||||
callbackURLs: z.array(z.string().url()).nonempty()
|
||||
});
|
||||
|
||||
type FormSchema = typeof formSchema;
|
||||
@@ -70,32 +71,40 @@
|
||||
</script>
|
||||
|
||||
<form onsubmit={onSubmit}>
|
||||
<div class="mt-3 grid grid-cols-2 gap-3">
|
||||
<FormInput label="Name" bind:input={$inputs.name} />
|
||||
<FormInput label="Callback URL" bind:input={$inputs.callbackURL} />
|
||||
<div class="mt-3">
|
||||
<Label for="logo">Logo</Label>
|
||||
<div class="mt-2 flex items-end gap-3">
|
||||
{#if logoDataURL}
|
||||
<div class="h-32 w-32 rounded-2xl bg-muted p-3">
|
||||
<img class="m-auto max-h-full max-w-full object-contain" src={logoDataURL} alt={`${$inputs.name.value} logo`} />
|
||||
</div>
|
||||
{/if}
|
||||
<div class="flex flex-col gap-2">
|
||||
<FileInput
|
||||
id="logo"
|
||||
variant="secondary"
|
||||
accept="image/png, image/jpeg, image/svg+xml"
|
||||
onchange={onLogoChange}
|
||||
>
|
||||
<Button variant="secondary">
|
||||
{existingClient?.hasLogo ? 'Change Logo' : 'Upload Logo'}
|
||||
</Button>
|
||||
</FileInput>
|
||||
{#if logoDataURL}
|
||||
<Button variant="outline" on:click={resetLogo}>Remove Logo</Button>
|
||||
{/if}
|
||||
<div class="flex flex-col gap-3 sm:flex-row">
|
||||
<FormInput label="Name" class="w-full" bind:input={$inputs.name} />
|
||||
<OidcCallbackUrlInput
|
||||
class="w-full"
|
||||
bind:callbackURLs={$inputs.callbackURLs.value}
|
||||
bind:error={$inputs.callbackURLs.error}
|
||||
/>
|
||||
</div>
|
||||
<div class="mt-3">
|
||||
<Label for="logo">Logo</Label>
|
||||
<div class="mt-2 flex items-end gap-3">
|
||||
{#if logoDataURL}
|
||||
<div class="bg-muted h-32 w-32 rounded-2xl p-3">
|
||||
<img
|
||||
class="m-auto max-h-full max-w-full object-contain"
|
||||
src={logoDataURL}
|
||||
alt={`${$inputs.name.value} logo`}
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
<div class="flex flex-col gap-2">
|
||||
<FileInput
|
||||
id="logo"
|
||||
variant="secondary"
|
||||
accept="image/png, image/jpeg, image/svg+xml"
|
||||
onchange={onLogoChange}
|
||||
>
|
||||
<Button variant="secondary">
|
||||
{existingClient?.hasLogo ? 'Change Logo' : 'Upload Logo'}
|
||||
</Button>
|
||||
</FileInput>
|
||||
{#if logoDataURL}
|
||||
<Button variant="outline" on:click={resetLogo}>Remove Logo</Button>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -26,9 +26,13 @@
|
||||
};
|
||||
|
||||
const formSchema = z.object({
|
||||
firstName: z.string().min(2).max(50),
|
||||
lastName: z.string().min(2).max(50),
|
||||
username: z.string().min(2).max(50),
|
||||
firstName: z.string().min(2).max(30),
|
||||
lastName: z.string().min(2).max(30),
|
||||
username: z
|
||||
.string()
|
||||
.min(2)
|
||||
.max(30)
|
||||
.regex(/^[a-z0-9_]+$/, 'Only lowercase letters, numbers, and underscores are allowed'),
|
||||
email: z.string().email(),
|
||||
isAdmin: z.boolean()
|
||||
});
|
||||
@@ -66,10 +70,10 @@
|
||||
<div class="items-top mt-5 flex space-x-2">
|
||||
<Checkbox id="admin-privileges" bind:checked={$inputs.isAdmin.value} />
|
||||
<div class="grid gap-1.5 leading-none">
|
||||
<Label for="admin-privileges" class="text-sm font-medium leading-none mb-0">
|
||||
<Label for="admin-privileges" class="mb-0 text-sm font-medium leading-none">
|
||||
Admin Privileges
|
||||
</Label>
|
||||
<p class="text-[0.8rem] text-muted-foreground">Admins have full access to the admin panel.</p>
|
||||
<p class="text-muted-foreground text-[0.8rem]">Admins have full access to the admin panel.</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="mt-5 flex justify-end">
|
||||
|
||||
@@ -69,15 +69,3 @@ test('Delete passkey from account', async ({ page }) => {
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Passkey deleted successfully');
|
||||
});
|
||||
|
||||
test('Delete last passkey from account fails', async ({ page }) => {
|
||||
await page.goto('/settings/account');
|
||||
|
||||
await page.getByLabel('Delete').first().click();
|
||||
await page.getByText('Delete', { exact: true }).click();
|
||||
|
||||
await page.getByLabel('Delete').first().click();
|
||||
await page.getByText('Delete', { exact: true }).click();
|
||||
|
||||
await expect(page.getByRole('status').first()).toHaveText('You must have at least one passkey');
|
||||
});
|
||||
|
||||
@@ -10,10 +10,15 @@ test('Update general configuration', async ({ page }) => {
|
||||
await page.getByLabel('Session Duration').fill('30');
|
||||
await page.getByRole('button', { name: 'Save' }).first().click();
|
||||
|
||||
await expect(page.getByTestId('application-name')).toHaveText('Updated Name');
|
||||
await expect(page.getByRole('status')).toHaveText(
|
||||
'Application configuration updated successfully'
|
||||
);
|
||||
await expect(page.getByTestId('application-name')).toHaveText('Updated Name');
|
||||
|
||||
await page.reload();
|
||||
|
||||
await expect(page.getByLabel('Name')).toHaveValue('Updated Name');
|
||||
await expect(page.getByLabel('Session Duration')).toHaveValue('30');
|
||||
});
|
||||
|
||||
test('Update application images', async ({ page }) => {
|
||||
|
||||
@@ -23,17 +23,18 @@ export const users = {
|
||||
|
||||
export const oidcClients = {
|
||||
nextcloud: {
|
||||
id: "3654a746-35d4-4321-ac61-0bdcff2b4055",
|
||||
id: '3654a746-35d4-4321-ac61-0bdcff2b4055',
|
||||
name: 'Nextcloud',
|
||||
callbackUrl: 'http://nextcloud/auth/callback'
|
||||
},
|
||||
immich: {
|
||||
id: "606c7782-f2b1-49e5-8ea9-26eb1b06d018",
|
||||
name: 'Immich',
|
||||
callbackUrl: 'http://immich/auth/callback'
|
||||
},
|
||||
immich: {
|
||||
id: '606c7782-f2b1-49e5-8ea9-26eb1b06d018',
|
||||
name: 'Immich',
|
||||
callbackUrl: 'http://immich/auth/callback'
|
||||
},
|
||||
pingvinShare: {
|
||||
name: 'Pingvin Share',
|
||||
callbackUrl: 'http://pingvin.share/auth/callback'
|
||||
callbackUrl: 'http://pingvin.share/auth/callback',
|
||||
secondCallbackUrl: 'http://pingvin.share/auth/callback2'
|
||||
}
|
||||
};
|
||||
|
||||
@@ -10,7 +10,11 @@ test('Create OIDC client', async ({ page }) => {
|
||||
|
||||
await page.getByRole('button', { name: 'Add OIDC Client' }).click();
|
||||
await page.getByLabel('Name').fill(oidcClient.name);
|
||||
await page.getByLabel('Callback URL').fill(oidcClient.callbackUrl);
|
||||
|
||||
await page.getByTestId('callback-url-1').fill(oidcClient.callbackUrl);
|
||||
await page.getByRole('button', { name: 'Add another' }).click();
|
||||
await page.getByTestId('callback-url-2').fill(oidcClient.secondCallbackUrl!);
|
||||
|
||||
await page.getByLabel('logo').setInputFiles('tests/assets/pingvin-share-logo.png');
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
@@ -20,7 +24,8 @@ test('Create OIDC client', async ({ page }) => {
|
||||
expect(clientId?.length).toBe(36);
|
||||
expect((await page.getByTestId('client-secret').textContent())?.length).toBe(32);
|
||||
await expect(page.getByLabel('Name')).toHaveValue(oidcClient.name);
|
||||
await expect(page.getByLabel('Callback URL')).toHaveValue(oidcClient.callbackUrl);
|
||||
await expect(page.getByTestId('callback-url-1')).toHaveValue(oidcClient.callbackUrl);
|
||||
await expect(page.getByTestId('callback-url-2')).toHaveValue(oidcClient.secondCallbackUrl!);
|
||||
await expect(page.getByRole('img', { name: `${oidcClient.name} logo` })).toBeVisible();
|
||||
await page.request
|
||||
.get(`/api/oidc/clients/${clientId}/logo`)
|
||||
@@ -32,7 +37,7 @@ test('Edit OIDC client', async ({ page }) => {
|
||||
await page.goto(`/settings/admin/oidc-clients/${oidcClient.id}`);
|
||||
|
||||
await page.getByLabel('Name').fill('Nextcloud updated');
|
||||
await page.getByLabel('Callback URL').fill('http://nextcloud-updated/auth/callback');
|
||||
await page.getByTestId('callback-url-1').fill('http://nextcloud-updated/auth/callback');
|
||||
await page.getByLabel('logo').setInputFiles('tests/assets/nextcloud-logo.png');
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
|
||||
@@ -35,6 +35,21 @@ test('Create user fails with already taken email', async ({ page }) => {
|
||||
await expect(page.getByRole('status')).toHaveText('Email is already taken');
|
||||
});
|
||||
|
||||
test('Create user fails with already taken username', async ({ page }) => {
|
||||
const user = users.steve;
|
||||
|
||||
await page.goto('/settings/admin/users');
|
||||
|
||||
await page.getByRole('button', { name: 'Add User' }).click();
|
||||
await page.getByLabel('Firstname').fill(user.firstname);
|
||||
await page.getByLabel('Lastname').fill(user.lastname);
|
||||
await page.getByLabel('Email').fill(user.email);
|
||||
await page.getByLabel('Username').fill(users.tim.username);
|
||||
await page.getByRole('button', { name: 'Save' }).click();
|
||||
|
||||
await expect(page.getByRole('status')).toHaveText('Username is already taken');
|
||||
});
|
||||
|
||||
test('Create one time access token', async ({ page }) => {
|
||||
await page.goto('/settings/admin/users');
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import axios from 'axios';
|
||||
|
||||
export async function cleanupBackend() {
|
||||
await axios.post('/api/test/reset');
|
||||
await axios.post('http://localhost/api/test/reset');
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user