mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-16 01:11:16 +03:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2dfb3da5d | ||
|
|
cbf0e3117d | ||
|
|
694f266dea | ||
|
|
29fc185376 | ||
|
|
781be37416 | ||
|
|
b1f97e05a1 | ||
|
|
2c74865173 | ||
|
|
ad8a90c839 | ||
|
|
f9839a978c | ||
|
|
b81de45166 | ||
|
|
22f4254932 | ||
|
|
507f9490fa | ||
|
|
043cce615d | ||
|
|
69e2083722 | ||
|
|
d47b20326f | ||
|
|
fc9939d1f1 | ||
|
|
2c1c67b5e4 | ||
|
|
d010be4c88 | ||
|
|
01db8c0a46 | ||
|
|
fe5917d96d | ||
|
|
4f0b434c54 | ||
|
|
6bdf5fa37a | ||
|
|
47bd5ba1ba | ||
|
|
b746ac0835 | ||
|
|
79989fb176 | ||
|
|
ecc7e224e9 | ||
|
|
549d219f44 | ||
|
|
ffe18db2fb | ||
|
|
e8b172f1c3 | ||
|
|
097bda349a | ||
|
|
6e24517197 | ||
|
|
a3da943aa6 | ||
|
|
cc34aca2a0 | ||
|
|
fde4e9b38a | ||
|
|
c55143d8c9 | ||
|
|
8973e93cb6 | ||
|
|
8c9cac2655 | ||
|
|
ed8547ccc1 | ||
|
|
e7e53a8b8c | ||
|
|
02249491f8 | ||
|
|
cf0892922b | ||
|
|
99f31a7c26 | ||
|
|
68373604dd | ||
|
|
2d6d5df0e7 | ||
|
|
a897b31166 | ||
|
|
fb92906c3a | ||
|
|
c018f29ad7 | ||
|
|
5367463239 | ||
|
|
6c9147483c | ||
|
|
d123d7f335 | ||
|
|
da8ca08c36 | ||
|
|
307caaa3ef | ||
|
|
6c696b46c8 | ||
|
|
42155238b7 | ||
|
|
92edc26a30 | ||
|
|
e36499c483 | ||
|
|
6215e1ac01 | ||
|
|
74b39e16f9 | ||
|
|
a1d8538c64 | ||
|
|
1d7cbc2a4e | ||
|
|
954fb4f0c8 | ||
|
|
901333f7e4 | ||
|
|
0b381467ca | ||
|
|
6188dc6fb7 | ||
|
|
802754c24c | ||
|
|
6c843228eb | ||
|
|
a3979f63e0 | ||
|
|
52c560c30d | ||
|
|
e88be7e61a | ||
|
|
a4e965434f | ||
|
|
096d214a88 | ||
|
|
afb7fc32e7 | ||
|
|
641bbc9351 | ||
|
|
136c6082f6 | ||
|
|
b9a20d2923 |
4
.github/workflows/backend-linter.yml
vendored
4
.github/workflows/backend-linter.yml
vendored
@@ -24,10 +24,10 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version-file: backend/go.mod
|
go-version-file: backend/go.mod
|
||||||
|
|
||||||
|
|||||||
20
.github/workflows/build-next.yml
vendored
20
.github/workflows/build-next.yml
vendored
@@ -19,24 +19,20 @@ jobs:
|
|||||||
attestations: write
|
attestations: write
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
|
||||||
version: 10
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: 'pnpm'
|
|
||||||
cache-dependency-path: pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Setup Go
|
- name: Setup Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version-file: 'backend/go.mod'
|
go-version-file: "backend/go.mod"
|
||||||
|
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v3
|
uses: docker/setup-qemu-action@v3
|
||||||
@@ -76,7 +72,7 @@ jobs:
|
|||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ env.DOCKER_IMAGE_NAME }}:next
|
tags: ${{ env.DOCKER_IMAGE_NAME }}:next
|
||||||
file: Dockerfile-prebuilt
|
file: docker/Dockerfile-prebuilt
|
||||||
- name: Build and push container image (distroless)
|
- name: Build and push container image (distroless)
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
id: container-build-push-distroless
|
id: container-build-push-distroless
|
||||||
@@ -85,16 +81,16 @@ jobs:
|
|||||||
platforms: linux/amd64,linux/arm64
|
platforms: linux/amd64,linux/arm64
|
||||||
push: true
|
push: true
|
||||||
tags: ${{ env.DOCKER_IMAGE_NAME }}:next-distroless
|
tags: ${{ env.DOCKER_IMAGE_NAME }}:next-distroless
|
||||||
file: Dockerfile-distroless
|
file: docker/Dockerfile-distroless
|
||||||
- name: Container image attestation
|
- name: Container image attestation
|
||||||
uses: actions/attest-build-provenance@v2
|
uses: actions/attest-build-provenance@v2
|
||||||
with:
|
with:
|
||||||
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
|
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||||
subject-digest: ${{ steps.build-push-image.outputs.digest }}
|
subject-digest: ${{ steps.build-push-image.outputs.digest }}
|
||||||
push-to-registry: true
|
push-to-registry: true
|
||||||
- name: Container image attestation (distroless)
|
- name: Container image attestation (distroless)
|
||||||
uses: actions/attest-build-provenance@v2
|
uses: actions/attest-build-provenance@v2
|
||||||
with:
|
with:
|
||||||
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
|
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||||
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
||||||
push-to-registry: true
|
push-to-registry: true
|
||||||
|
|||||||
13
.github/workflows/e2e-tests.yml
vendored
13
.github/workflows/e2e-tests.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
|||||||
actions: write
|
actions: write
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
@@ -30,6 +30,8 @@ jobs:
|
|||||||
- name: Build and export
|
- name: Build and export
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
with:
|
with:
|
||||||
|
context: .
|
||||||
|
file: docker/Dockerfile
|
||||||
push: false
|
push: false
|
||||||
load: false
|
load: false
|
||||||
tags: pocket-id:test
|
tags: pocket-id:test
|
||||||
@@ -57,18 +59,15 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
db: [sqlite, postgres]
|
db: [sqlite, postgres]
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
|
||||||
version: 10
|
|
||||||
|
|
||||||
- uses: actions/setup-node@v4
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: "pnpm"
|
|
||||||
cache-dependency-path: pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Cache Playwright Browsers
|
- name: Cache Playwright Browsers
|
||||||
uses: actions/cache@v3
|
uses: actions/cache@v3
|
||||||
|
|||||||
25
.github/workflows/release.yml
vendored
25
.github/workflows/release.yml
vendored
@@ -3,7 +3,7 @@ name: Release
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- 'v*.*.*'
|
- "v*.*.*"
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
@@ -18,17 +18,13 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
|
||||||
version: 10
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: 'pnpm'
|
- uses: actions/setup-go@v6
|
||||||
cache-dependency-path: pnpm-lock.yaml
|
|
||||||
- uses: actions/setup-go@v5
|
|
||||||
with:
|
with:
|
||||||
go-version-file: 'backend/go.mod'
|
go-version-file: "backend/go.mod"
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@v3
|
uses: docker/setup-qemu-action@v3
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
@@ -71,6 +67,7 @@ jobs:
|
|||||||
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
|
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
|
||||||
- name: Build frontend
|
- name: Build frontend
|
||||||
run: pnpm --filter pocket-id-frontend build
|
run: pnpm --filter pocket-id-frontend build
|
||||||
|
|
||||||
- name: Build binaries
|
- name: Build binaries
|
||||||
run: sh scripts/development/build-binaries.sh
|
run: sh scripts/development/build-binaries.sh
|
||||||
- name: Build and push container image
|
- name: Build and push container image
|
||||||
@@ -82,7 +79,7 @@ jobs:
|
|||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta.outputs.tags }}
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
labels: ${{ steps.meta.outputs.labels }}
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
file: Dockerfile-prebuilt
|
file: docker/Dockerfile-prebuilt
|
||||||
- name: Build and push container image (distroless)
|
- name: Build and push container image (distroless)
|
||||||
uses: docker/build-push-action@v6
|
uses: docker/build-push-action@v6
|
||||||
id: container-build-push-distroless
|
id: container-build-push-distroless
|
||||||
@@ -92,21 +89,21 @@ jobs:
|
|||||||
push: true
|
push: true
|
||||||
tags: ${{ steps.meta-distroless.outputs.tags }}
|
tags: ${{ steps.meta-distroless.outputs.tags }}
|
||||||
labels: ${{ steps.meta-distroless.outputs.labels }}
|
labels: ${{ steps.meta-distroless.outputs.labels }}
|
||||||
file: Dockerfile-distroless
|
file: docker/Dockerfile-distroless
|
||||||
- name: Binary attestation
|
- name: Binary attestation
|
||||||
uses: actions/attest-build-provenance@v2
|
uses: actions/attest-build-provenance@v2
|
||||||
with:
|
with:
|
||||||
subject-path: 'backend/.bin/pocket-id-**'
|
subject-path: "backend/.bin/pocket-id-**"
|
||||||
- name: Container image attestation
|
- name: Container image attestation
|
||||||
uses: actions/attest-build-provenance@v2
|
uses: actions/attest-build-provenance@v2
|
||||||
with:
|
with:
|
||||||
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
|
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||||
subject-digest: ${{ steps.container-build-push.outputs.digest }}
|
subject-digest: ${{ steps.container-build-push.outputs.digest }}
|
||||||
push-to-registry: true
|
push-to-registry: true
|
||||||
- name: Container image attestation (distroless)
|
- name: Container image attestation (distroless)
|
||||||
uses: actions/attest-build-provenance@v2
|
uses: actions/attest-build-provenance@v2
|
||||||
with:
|
with:
|
||||||
subject-name: '${{ env.DOCKER_IMAGE_NAME }}'
|
subject-name: "${{ env.DOCKER_IMAGE_NAME }}"
|
||||||
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
subject-digest: ${{ steps.container-build-push-distroless.outputs.digest }}
|
||||||
push-to-registry: true
|
push-to-registry: true
|
||||||
- name: Upload binaries to release
|
- name: Upload binaries to release
|
||||||
@@ -123,6 +120,6 @@ jobs:
|
|||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
- name: Mark release as published
|
- name: Mark release as published
|
||||||
run: gh release edit ${{ github.ref_name }} --draft=false
|
run: gh release edit ${{ github.ref_name }} --draft=false
|
||||||
|
|||||||
32
.github/workflows/svelte-check.yml
vendored
32
.github/workflows/svelte-check.yml
vendored
@@ -4,21 +4,21 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
paths:
|
paths:
|
||||||
- 'frontend/src/**'
|
- "frontend/src/**"
|
||||||
- '.github/svelte-check-matcher.json'
|
- ".github/svelte-check-matcher.json"
|
||||||
- 'frontend/package.json'
|
- "frontend/package.json"
|
||||||
- 'frontend/package-lock.json'
|
- "frontend/package-lock.json"
|
||||||
- 'frontend/tsconfig.json'
|
- "frontend/tsconfig.json"
|
||||||
- 'frontend/svelte.config.js'
|
- "frontend/svelte.config.js"
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
paths:
|
paths:
|
||||||
- 'frontend/src/**'
|
- "frontend/src/**"
|
||||||
- '.github/svelte-check-matcher.json'
|
- ".github/svelte-check-matcher.json"
|
||||||
- 'frontend/package.json'
|
- "frontend/package.json"
|
||||||
- 'frontend/package-lock.json'
|
- "frontend/package-lock.json"
|
||||||
- 'frontend/tsconfig.json'
|
- "frontend/tsconfig.json"
|
||||||
- 'frontend/svelte.config.js'
|
- "frontend/svelte.config.js"
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
@@ -34,19 +34,15 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Setup pnpm
|
- name: Setup pnpm
|
||||||
uses: pnpm/action-setup@v4
|
uses: pnpm/action-setup@v4
|
||||||
with:
|
|
||||||
version: 10
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@v4
|
uses: actions/setup-node@v5
|
||||||
with:
|
with:
|
||||||
node-version: 22
|
node-version: 22
|
||||||
cache: 'pnpm'
|
|
||||||
cache-dependency-path: pnpm-lock.yaml
|
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
|
run: pnpm --filter pocket-id-frontend install --frozen-lockfile
|
||||||
|
|||||||
4
.github/workflows/unit-tests.yml
vendored
4
.github/workflows/unit-tests.yml
vendored
@@ -16,8 +16,8 @@ jobs:
|
|||||||
actions: write
|
actions: write
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v5
|
||||||
- uses: actions/setup-go@v5
|
- uses: actions/setup-go@v6
|
||||||
with:
|
with:
|
||||||
go-version-file: "backend/go.mod"
|
go-version-file: "backend/go.mod"
|
||||||
cache-dependency-path: "backend/go.sum"
|
cache-dependency-path: "backend/go.sum"
|
||||||
|
|||||||
2
.github/workflows/update-aaguids.yml
vendored
2
.github/workflows/update-aaguids.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Fetch JSON data
|
- name: Fetch JSON data
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
1880
CHANGELOG.md
1880
CHANGELOG.md
File diff suppressed because it is too large
Load Diff
@@ -61,4 +61,4 @@ formatters:
|
|||||||
paths:
|
paths:
|
||||||
- third_party$
|
- third_party$
|
||||||
- builtin$
|
- builtin$
|
||||||
- examples$
|
- examples$
|
||||||
@@ -3,8 +3,10 @@
|
|||||||
package frontend
|
package frontend
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,11 +14,55 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed all:dist/*
|
//go:embed all:dist/*
|
||||||
var frontendFS embed.FS
|
var frontendFS embed.FS
|
||||||
|
|
||||||
|
// This function, created by the init() method, writes to "w" the index.html page, populating the nonce
|
||||||
|
var writeIndexFn func(w io.Writer, nonce string) error
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
const scriptTag = "<script>"
|
||||||
|
|
||||||
|
// Read the index.html from the bundle
|
||||||
|
index, iErr := fs.ReadFile(frontendFS, "dist/index.html")
|
||||||
|
if iErr != nil {
|
||||||
|
panic(fmt.Errorf("failed to read index.html: %w", iErr))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the position of the first <script> tag
|
||||||
|
idx := bytes.Index(index, []byte(scriptTag))
|
||||||
|
|
||||||
|
// Create writeIndexFn, which adds the CSP tag to the script tag if needed
|
||||||
|
writeIndexFn = func(w io.Writer, nonce string) (err error) {
|
||||||
|
// If there's no nonce, write the index as-is
|
||||||
|
if nonce == "" {
|
||||||
|
_, err = w.Write(index)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// We have a nonce, so first write the index until the <script> tag
|
||||||
|
// Then we write the modified script tag
|
||||||
|
// Finally, the rest of the index
|
||||||
|
_, err = w.Write(index[0:idx])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.Write([]byte(`<script nonce="` + nonce + `">`))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = w.Write(index[(idx + len(scriptTag)):])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func RegisterFrontend(router *gin.Engine) error {
|
func RegisterFrontend(router *gin.Engine) error {
|
||||||
distFS, err := fs.Sub(frontendFS, "dist")
|
distFS, err := fs.Sub(frontendFS, "dist")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -27,13 +73,39 @@ func RegisterFrontend(router *gin.Engine) error {
|
|||||||
fileServer := NewFileServerWithCaching(http.FS(distFS), int(cacheMaxAge.Seconds()))
|
fileServer := NewFileServerWithCaching(http.FS(distFS), int(cacheMaxAge.Seconds()))
|
||||||
|
|
||||||
router.NoRoute(func(c *gin.Context) {
|
router.NoRoute(func(c *gin.Context) {
|
||||||
// Try to serve the requested file
|
|
||||||
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
path := strings.TrimPrefix(c.Request.URL.Path, "/")
|
||||||
if _, err := fs.Stat(distFS, path); os.IsNotExist(err) {
|
|
||||||
// File doesn't exist, serve index.html instead
|
if strings.HasPrefix(path, "api/") {
|
||||||
c.Request.URL.Path = "/"
|
c.JSON(http.StatusNotFound, gin.H{"error": "API endpoint not found"})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If path is / or does not exist, serve index.html
|
||||||
|
if path == "" {
|
||||||
|
path = "index.html"
|
||||||
|
} else if _, err := fs.Stat(distFS, path); os.IsNotExist(err) {
|
||||||
|
path = "index.html"
|
||||||
|
}
|
||||||
|
|
||||||
|
if path == "index.html" {
|
||||||
|
nonce := middleware.GetCSPNonce(c)
|
||||||
|
|
||||||
|
// Do not cache the HTML shell, as it embeds a per-request nonce
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.Header("Cache-Control", "no-store")
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
|
||||||
|
err = writeIndexFn(c.Writer, nonce)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(fmt.Errorf("failed to write index.html file: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve other static assets with caching
|
||||||
|
c.Request.URL.Path = "/" + path
|
||||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ require (
|
|||||||
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
|
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21
|
||||||
github.com/emersion/go-smtp v0.21.3
|
github.com/emersion/go-smtp v0.21.3
|
||||||
github.com/fxamacker/cbor/v2 v2.9.0
|
github.com/fxamacker/cbor/v2 v2.9.0
|
||||||
|
github.com/gin-contrib/slog v1.1.0
|
||||||
github.com/gin-gonic/gin v1.10.1
|
github.com/gin-gonic/gin v1.10.1
|
||||||
github.com/glebarez/go-sqlite v1.22.0
|
github.com/glebarez/go-sqlite v1.22.0
|
||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
@@ -29,7 +30,6 @@ require (
|
|||||||
github.com/mileusna/useragent v1.3.5
|
github.com/mileusna/useragent v1.3.5
|
||||||
github.com/orandin/slog-gorm v1.4.0
|
github.com/orandin/slog-gorm v1.4.0
|
||||||
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.8
|
github.com/oschwald/maxminddb-golang/v2 v2.0.0-beta.8
|
||||||
github.com/samber/slog-gin v1.15.1
|
|
||||||
github.com/spf13/cobra v1.9.1
|
github.com/spf13/cobra v1.9.1
|
||||||
github.com/stretchr/testify v1.10.0
|
github.com/stretchr/testify v1.10.0
|
||||||
go.opentelemetry.io/contrib/bridges/otelslog v0.12.0
|
go.opentelemetry.io/contrib/bridges/otelslog v0.12.0
|
||||||
@@ -45,6 +45,7 @@ require (
|
|||||||
go.opentelemetry.io/otel/trace v1.37.0
|
go.opentelemetry.io/otel/trace v1.37.0
|
||||||
golang.org/x/crypto v0.41.0
|
golang.org/x/crypto v0.41.0
|
||||||
golang.org/x/image v0.30.0
|
golang.org/x/image v0.30.0
|
||||||
|
golang.org/x/sync v0.16.0
|
||||||
golang.org/x/text v0.28.0
|
golang.org/x/text v0.28.0
|
||||||
golang.org/x/time v0.12.0
|
golang.org/x/time v0.12.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
@@ -135,7 +136,6 @@ require (
|
|||||||
golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect
|
golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect
|
||||||
golang.org/x/net v0.43.0 // indirect
|
golang.org/x/net v0.43.0 // indirect
|
||||||
golang.org/x/oauth2 v0.27.0 // indirect
|
golang.org/x/oauth2 v0.27.0 // indirect
|
||||||
golang.org/x/sync v0.16.0 // indirect
|
|
||||||
golang.org/x/sys v0.35.0 // indirect
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250218202821-56aae31c358a // indirect
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sa
|
|||||||
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
||||||
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
||||||
|
github.com/gin-contrib/slog v1.1.0 h1:K9MVNrETT6r/C3u2Aheer/gxwVeVqrGL0hXlsmv3fm4=
|
||||||
|
github.com/gin-contrib/slog v1.1.0/go.mod h1:PvNXQVXcVOAaaiJR84LV1/xlQHIaXi9ygEXyBkmjdkY=
|
||||||
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||||
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
|
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
|
||||||
@@ -241,8 +243,6 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
|
|||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/samber/slog-gin v1.15.1 h1:jsnfr+S5HQPlz9pFPA3tOmKW7wN/znyZiE6hncucrTM=
|
|
||||||
github.com/samber/slog-gin v1.15.1/go.mod h1:mPAEinK/g2jPLauuWO11m3Q0Ca7aG4k9XjXjXY8IhMQ=
|
|
||||||
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
|
||||||
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
|
||||||
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo=
|
||||||
|
|||||||
120
backend/internal/bootstrap/app_images_bootstrap.go
Normal file
120
backend/internal/bootstrap/app_images_bootstrap.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package bootstrap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/resources"
|
||||||
|
)
|
||||||
|
|
||||||
|
// initApplicationImages copies the images from the images directory to the application-images directory
|
||||||
|
// and returns a map containing the detected file extensions in the application-images directory.
|
||||||
|
func initApplicationImages() (map[string]string, error) {
|
||||||
|
// Previous versions of images
|
||||||
|
// If these are found, they are deleted
|
||||||
|
legacyImageHashes := imageHashMap{
|
||||||
|
"background.jpg": mustDecodeHex("138d510030ed845d1d74de34658acabff562d306476454369a60ab8ade31933f"),
|
||||||
|
}
|
||||||
|
|
||||||
|
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
||||||
|
|
||||||
|
sourceFiles, err := resources.FS.ReadDir("images")
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinationFiles, err := os.ReadDir(dirPath)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("failed to read directory: %w", err)
|
||||||
|
}
|
||||||
|
dstNameToExt := make(map[string]string, len(destinationFiles))
|
||||||
|
for _, f := range destinationFiles {
|
||||||
|
if f.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := f.Name()
|
||||||
|
nameWithoutExt, ext := utils.SplitFileName(name)
|
||||||
|
destFilePath := path.Join(dirPath, name)
|
||||||
|
|
||||||
|
// Skip directories
|
||||||
|
if f.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := utils.CreateSha256FileHash(destFilePath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("Failed to get hash for file", slog.String("name", name), slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the file is a legacy one - if so, delete it
|
||||||
|
if legacyImageHashes.Contains(h) {
|
||||||
|
slog.Info("Found legacy application image that will be removed", slog.String("name", name))
|
||||||
|
err = os.Remove(destFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to remove legacy file '%s': %w", name, err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track existing files
|
||||||
|
dstNameToExt[nameWithoutExt] = ext
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy images from the images directory to the application-images directory if they don't already exist
|
||||||
|
for _, sourceFile := range sourceFiles {
|
||||||
|
if sourceFile.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := sourceFile.Name()
|
||||||
|
nameWithoutExt, ext := utils.SplitFileName(name)
|
||||||
|
srcFilePath := path.Join("images", name)
|
||||||
|
destFilePath := path.Join(dirPath, name)
|
||||||
|
|
||||||
|
// Skip if there's already an image at the path
|
||||||
|
// We do not check the extension because users could have uploaded a different one
|
||||||
|
if _, exists := dstNameToExt[nameWithoutExt]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Info("Writing new application image", slog.String("name", name))
|
||||||
|
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to copy file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track the newly copied file so it can be included in the extensions map later
|
||||||
|
dstNameToExt[nameWithoutExt] = ext
|
||||||
|
}
|
||||||
|
|
||||||
|
return dstNameToExt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type imageHashMap map[string][]byte
|
||||||
|
|
||||||
|
func (m imageHashMap) Contains(target []byte) bool {
|
||||||
|
if len(target) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, h := range m {
|
||||||
|
if bytes.Equal(h, target) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustDecodeHex(str string) []byte {
|
||||||
|
b, err := hex.DecodeString(str)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
||||||
"github.com/pocket-id/pocket-id/backend/resources"
|
|
||||||
)
|
|
||||||
|
|
||||||
// initApplicationImages copies the images from the images directory to the application-images directory
|
|
||||||
func initApplicationImages() error {
|
|
||||||
dirPath := common.EnvConfig.UploadPath + "/application-images"
|
|
||||||
|
|
||||||
sourceFiles, err := resources.FS.ReadDir("images")
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("failed to read directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
destinationFiles, err := os.ReadDir(dirPath)
|
|
||||||
if err != nil && !os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("failed to read directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy images from the images directory to the application-images directory if they don't already exist
|
|
||||||
for _, sourceFile := range sourceFiles {
|
|
||||||
if sourceFile.IsDir() || imageAlreadyExists(sourceFile.Name(), destinationFiles) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
srcFilePath := path.Join("images", sourceFile.Name())
|
|
||||||
destFilePath := path.Join(dirPath, sourceFile.Name())
|
|
||||||
|
|
||||||
err := utils.CopyEmbeddedFileToDisk(srcFilePath, destFilePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to copy file: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func imageAlreadyExists(fileName string, destinationFiles []os.DirEntry) bool {
|
|
||||||
for _, destinationFile := range destinationFiles {
|
|
||||||
sourceFileWithoutExtension := getImageNameWithoutExtension(fileName)
|
|
||||||
destinationFileWithoutExtension := getImageNameWithoutExtension(destinationFile.Name())
|
|
||||||
|
|
||||||
if sourceFileWithoutExtension == destinationFileWithoutExtension {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func getImageNameWithoutExtension(fileName string) string {
|
|
||||||
idx := strings.LastIndexByte(fileName, '.')
|
|
||||||
if idx < 1 {
|
|
||||||
// No dot found, or fileName starts with a dot
|
|
||||||
return fileName
|
|
||||||
}
|
|
||||||
|
|
||||||
return fileName[:idx]
|
|
||||||
}
|
|
||||||
@@ -21,7 +21,7 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
slog.InfoContext(ctx, "Pocket ID is starting")
|
slog.InfoContext(ctx, "Pocket ID is starting")
|
||||||
|
|
||||||
err = initApplicationImages()
|
imageExtensions, err := initApplicationImages()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize application images: %w", err)
|
return fmt.Errorf("failed to initialize application images: %w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +33,7 @@ func Bootstrap(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create all services
|
// Create all services
|
||||||
svc, err := initServices(ctx, db, httpClient)
|
svc, err := initServices(ctx, db, httpClient, imageExtensions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize services: %w", err)
|
return fmt.Errorf("failed to initialize services: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package bootstrap
|
package bootstrap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -140,6 +141,7 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
var dialector gorm.Dialector
|
var dialector gorm.Dialector
|
||||||
|
|
||||||
// Choose the correct database provider
|
// Choose the correct database provider
|
||||||
|
var onConnFn func(conn *sql.DB)
|
||||||
switch common.EnvConfig.DbProvider {
|
switch common.EnvConfig.DbProvider {
|
||||||
case common.DbProviderSqlite:
|
case common.DbProviderSqlite:
|
||||||
if common.EnvConfig.DbConnectionString == "" {
|
if common.EnvConfig.DbConnectionString == "" {
|
||||||
@@ -148,7 +150,7 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
|
|
||||||
sqliteutil.RegisterSqliteFunctions()
|
sqliteutil.RegisterSqliteFunctions()
|
||||||
|
|
||||||
connString, dbPath, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
connString, dbPath, isMemoryDB, err := parseSqliteConnectionString(common.EnvConfig.DbConnectionString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -159,6 +161,14 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isMemoryDB {
|
||||||
|
// For in-memory SQLite databases, we must limit to 1 open connection at the same time, or they won't see the whole data
|
||||||
|
// The other workaround, of using shared caches, doesn't work well with multiple write transactions trying to happen at once
|
||||||
|
onConnFn = func(conn *sql.DB) {
|
||||||
|
conn.SetMaxOpenConns(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dialector = sqlite.Open(connString)
|
dialector = sqlite.Open(connString)
|
||||||
case common.DbProviderPostgres:
|
case common.DbProviderPostgres:
|
||||||
if common.EnvConfig.DbConnectionString == "" {
|
if common.EnvConfig.DbConnectionString == "" {
|
||||||
@@ -176,6 +186,16 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
})
|
})
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("Connected to database", slog.String("provider", string(common.EnvConfig.DbProvider)))
|
slog.Info("Connected to database", slog.String("provider", string(common.EnvConfig.DbProvider)))
|
||||||
|
|
||||||
|
if onConnFn != nil {
|
||||||
|
conn, err := db.DB()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("Failed to get database connection, will retry in 3s", slog.Int("attempt", i), slog.String("provider", string(common.EnvConfig.DbProvider)), slog.Any("error", err))
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
}
|
||||||
|
onConnFn(conn)
|
||||||
|
}
|
||||||
|
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,18 +208,18 @@ func connectDatabase() (db *gorm.DB, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSqliteConnectionString(connString string) (parsedConnString string, dbPath string, err error) {
|
func parseSqliteConnectionString(connString string) (parsedConnString string, dbPath string, isMemoryDB bool, err error) {
|
||||||
if !strings.HasPrefix(connString, "file:") {
|
if !strings.HasPrefix(connString, "file:") {
|
||||||
connString = "file:" + connString
|
connString = "file:" + connString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if we're using an in-memory database
|
// Check if we're using an in-memory database
|
||||||
isMemoryDB := isSqliteInMemory(connString)
|
isMemoryDB = isSqliteInMemory(connString)
|
||||||
|
|
||||||
// Parse the connection string
|
// Parse the connection string
|
||||||
connStringUrl, err := url.Parse(connString)
|
connStringUrl, err := url.Parse(connString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
return "", "", false, fmt.Errorf("failed to parse SQLite connection string: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert options for the old SQLite driver to the new one
|
// Convert options for the old SQLite driver to the new one
|
||||||
@@ -208,7 +228,7 @@ func parseSqliteConnectionString(connString string) (parsedConnString string, db
|
|||||||
// Add the default and required params
|
// Add the default and required params
|
||||||
err = addSqliteDefaultParameters(connStringUrl, isMemoryDB)
|
err = addSqliteDefaultParameters(connStringUrl, isMemoryDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("invalid SQLite connection string: %w", err)
|
return "", "", false, fmt.Errorf("invalid SQLite connection string: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the absolute path to the database
|
// Get the absolute path to the database
|
||||||
@@ -217,10 +237,10 @@ func parseSqliteConnectionString(connString string) (parsedConnString string, db
|
|||||||
idx := strings.IndexRune(parsedConnString, '?')
|
idx := strings.IndexRune(parsedConnString, '?')
|
||||||
dbPath, err = filepath.Abs(parsedConnString[len("file:"):idx])
|
dbPath, err = filepath.Abs(parsedConnString[len("file:"):idx])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("failed to determine absolute path to the database: %w", err)
|
return "", "", false, fmt.Errorf("failed to determine absolute path to the database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return parsedConnString, dbPath, nil
|
return parsedConnString, dbPath, isMemoryDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// The official C implementation of SQLite allows some additional properties in the connection string
|
// The official C implementation of SQLite allows some additional properties in the connection string
|
||||||
@@ -272,11 +292,6 @@ func addSqliteDefaultParameters(connStringUrl *url.URL, isMemoryDB bool) error {
|
|||||||
qs = make(url.Values, 2)
|
qs = make(url.Values, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the database is in-memory, we must ensure that cache=shared is set
|
|
||||||
if isMemoryDB {
|
|
||||||
qs["cache"] = []string{"shared"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the database is read-only or immutable
|
// Check if the database is read-only or immutable
|
||||||
isReadOnly := false
|
isReadOnly := false
|
||||||
if len(qs["mode"]) > 0 {
|
if len(qs["mode"]) > 0 {
|
||||||
@@ -422,17 +437,18 @@ func getGormLogger() gormLogger.Interface {
|
|||||||
slogGorm.WithErrorField("error"),
|
slogGorm.WithErrorField("error"),
|
||||||
)
|
)
|
||||||
|
|
||||||
if common.EnvConfig.AppEnv == "production" {
|
if common.EnvConfig.LogLevel == "debug" {
|
||||||
loggerOpts = append(loggerOpts,
|
|
||||||
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelWarn),
|
|
||||||
slogGorm.WithIgnoreTrace(),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
loggerOpts = append(loggerOpts,
|
loggerOpts = append(loggerOpts,
|
||||||
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelDebug),
|
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelDebug),
|
||||||
slogGorm.WithRecordNotFoundError(),
|
slogGorm.WithRecordNotFoundError(),
|
||||||
slogGorm.WithTraceAll(),
|
slogGorm.WithTraceAll(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
loggerOpts = append(loggerOpts,
|
||||||
|
slogGorm.SetLogLevel(slogGorm.DefaultLogType, slog.LevelWarn),
|
||||||
|
slogGorm.WithIgnoreTrace(),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
return slogGorm.New(loggerOpts...)
|
return slogGorm.New(loggerOpts...)
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func TestAddSqliteDefaultParameters(t *testing.T) {
|
|||||||
name: "in-memory database",
|
name: "in-memory database",
|
||||||
input: "file::memory:",
|
input: "file::memory:",
|
||||||
isMemoryDB: true,
|
isMemoryDB: true,
|
||||||
expected: "file::memory:?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28MEMORY%29&_txlock=immediate&cache=shared",
|
expected: "file::memory:?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28MEMORY%29&_txlock=immediate",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "read-only database with mode=ro",
|
name: "read-only database with mode=ro",
|
||||||
@@ -249,12 +249,6 @@ func TestAddSqliteDefaultParameters(t *testing.T) {
|
|||||||
isMemoryDB: false,
|
isMemoryDB: false,
|
||||||
expected: "file:test.db?_pragma=busy_timeout%283000%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28TRUNCATE%29&_pragma=synchronous%28NORMAL%29&_txlock=immediate",
|
expected: "file:test.db?_pragma=busy_timeout%283000%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28TRUNCATE%29&_pragma=synchronous%28NORMAL%29&_txlock=immediate",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "in-memory database with cache already set",
|
|
||||||
input: "file::memory:?cache=private",
|
|
||||||
isMemoryDB: true,
|
|
||||||
expected: "file::memory:?_pragma=busy_timeout%282500%29&_pragma=foreign_keys%281%29&_pragma=journal_mode%28MEMORY%29&_txlock=immediate&cache=shared",
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "database with mode=rw (not read-only)",
|
name: "database with mode=rw (not read-only)",
|
||||||
input: "file:test.db?mode=rw",
|
input: "file:test.db?mode=rw",
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sloggin "github.com/gin-contrib/slog"
|
||||||
|
|
||||||
"github.com/lmittmann/tint"
|
"github.com/lmittmann/tint"
|
||||||
"github.com/mattn/go-isatty"
|
"github.com/mattn/go-isatty"
|
||||||
"go.opentelemetry.io/contrib/bridges/otelslog"
|
"go.opentelemetry.io/contrib/bridges/otelslog"
|
||||||
@@ -89,28 +91,19 @@ func initOtelLogging(ctx context.Context, resource *resource.Resource) error {
|
|||||||
return fmt.Errorf("failed to initialize OpenTelemetry log exporter: %w", err)
|
return fmt.Errorf("failed to initialize OpenTelemetry log exporter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
level := slog.LevelDebug
|
level, _ := sloggin.ParseLevel(common.EnvConfig.LogLevel)
|
||||||
if common.EnvConfig.AppEnv == "production" {
|
|
||||||
level = slog.LevelInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the handler
|
// Create the handler
|
||||||
var handler slog.Handler
|
var handler slog.Handler
|
||||||
switch {
|
if common.EnvConfig.LogJSON {
|
||||||
case common.EnvConfig.LogJSON:
|
|
||||||
// Log as JSON if configured
|
|
||||||
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
handler = slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||||
Level: level,
|
Level: level,
|
||||||
})
|
})
|
||||||
case isatty.IsTerminal(os.Stdout.Fd()):
|
} else {
|
||||||
// Enable colors if we have a TTY
|
|
||||||
handler = tint.NewHandler(os.Stdout, &tint.Options{
|
handler = tint.NewHandler(os.Stdout, &tint.Options{
|
||||||
TimeFormat: time.StampMilli,
|
TimeFormat: time.Stamp,
|
||||||
Level: level,
|
Level: level,
|
||||||
})
|
NoColor: !isatty.IsTerminal(os.Stdout.Fd()),
|
||||||
default:
|
|
||||||
handler = slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
|
|
||||||
Level: level,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
sloggin "github.com/gin-contrib/slog"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
sloggin "github.com/samber/slog-gin"
|
|
||||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@@ -49,30 +49,8 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
|||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// do not log these URLs
|
|
||||||
loggerSkipPathsPrefix := []string{
|
|
||||||
"GET /application-configuration/logo",
|
|
||||||
"GET /application-configuration/background-image",
|
|
||||||
"GET /application-configuration/favicon",
|
|
||||||
"GET /_app",
|
|
||||||
"GET /fonts",
|
|
||||||
"GET /healthz",
|
|
||||||
"HEAD /healthz",
|
|
||||||
}
|
|
||||||
|
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(sloggin.NewWithConfig(slog.Default(), sloggin.Config{
|
initLogger(r)
|
||||||
Filters: []sloggin.Filter{
|
|
||||||
func(c *gin.Context) bool {
|
|
||||||
for _, prefix := range loggerSkipPathsPrefix {
|
|
||||||
if strings.HasPrefix(c.Request.Method+" "+c.Request.URL.String(), prefix) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}))
|
|
||||||
|
|
||||||
if !common.EnvConfig.TrustProxy {
|
if !common.EnvConfig.TrustProxy {
|
||||||
_ = r.SetTrustedProxies(nil)
|
_ = r.SetTrustedProxies(nil)
|
||||||
@@ -86,6 +64,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
|||||||
|
|
||||||
// Setup global middleware
|
// Setup global middleware
|
||||||
r.Use(middleware.NewCorsMiddleware().Add())
|
r.Use(middleware.NewCorsMiddleware().Add())
|
||||||
|
r.Use(middleware.NewCspMiddleware().Add())
|
||||||
r.Use(middleware.NewErrorHandlerMiddleware().Add())
|
r.Use(middleware.NewErrorHandlerMiddleware().Add())
|
||||||
|
|
||||||
err := frontend.RegisterFrontend(r)
|
err := frontend.RegisterFrontend(r)
|
||||||
@@ -106,9 +85,11 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
|||||||
controller.NewOidcController(apiGroup, authMiddleware, fileSizeLimitMiddleware, svc.oidcService, svc.jwtService)
|
controller.NewOidcController(apiGroup, authMiddleware, fileSizeLimitMiddleware, svc.oidcService, svc.jwtService)
|
||||||
controller.NewUserController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), svc.userService, svc.appConfigService)
|
controller.NewUserController(apiGroup, authMiddleware, middleware.NewRateLimitMiddleware(), svc.userService, svc.appConfigService)
|
||||||
controller.NewAppConfigController(apiGroup, authMiddleware, svc.appConfigService, svc.emailService, svc.ldapService)
|
controller.NewAppConfigController(apiGroup, authMiddleware, svc.appConfigService, svc.emailService, svc.ldapService)
|
||||||
|
controller.NewAppImagesController(apiGroup, authMiddleware, svc.appImagesService)
|
||||||
controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware)
|
controller.NewAuditLogController(apiGroup, svc.auditLogService, authMiddleware)
|
||||||
controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService)
|
controller.NewUserGroupController(apiGroup, authMiddleware, svc.userGroupService)
|
||||||
controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService)
|
controller.NewCustomClaimController(apiGroup, authMiddleware, svc.customClaimService)
|
||||||
|
controller.NewVersionController(apiGroup, svc.versionService)
|
||||||
|
|
||||||
// Add test controller in non-production environments
|
// Add test controller in non-production environments
|
||||||
if common.EnvConfig.AppEnv != "production" {
|
if common.EnvConfig.AppEnv != "production" {
|
||||||
@@ -138,6 +119,7 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
|||||||
if common.EnvConfig.UnixSocket != "" {
|
if common.EnvConfig.UnixSocket != "" {
|
||||||
network = "unix"
|
network = "unix"
|
||||||
addr = common.EnvConfig.UnixSocket
|
addr = common.EnvConfig.UnixSocket
|
||||||
|
os.Remove(addr) // remove dangling the socket file to avoid file-exist error
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := net.Listen(network, addr) //nolint:noctx
|
listener, err := net.Listen(network, addr) //nolint:noctx
|
||||||
@@ -198,3 +180,29 @@ func initRouterInternal(db *gorm.DB, svc *services) (utils.Service, error) {
|
|||||||
|
|
||||||
return runFn, nil
|
return runFn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initLogger(r *gin.Engine) {
|
||||||
|
loggerSkipPathsPrefix := []string{
|
||||||
|
"GET /api/application-images/logo",
|
||||||
|
"GET /api/application-images/background",
|
||||||
|
"GET /api/application-images/favicon",
|
||||||
|
"GET /_app",
|
||||||
|
"GET /fonts",
|
||||||
|
"GET /healthz",
|
||||||
|
"HEAD /healthz",
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Use(sloggin.SetLogger(
|
||||||
|
sloggin.WithLogger(func(_ *gin.Context, _ *slog.Logger) *slog.Logger {
|
||||||
|
return slog.Default()
|
||||||
|
}),
|
||||||
|
sloggin.WithSkipper(func(c *gin.Context) bool {
|
||||||
|
for _, prefix := range loggerSkipPathsPrefix {
|
||||||
|
if strings.HasPrefix(c.Request.Method+" "+c.Request.URL.String(), prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
type services struct {
|
type services struct {
|
||||||
appConfigService *service.AppConfigService
|
appConfigService *service.AppConfigService
|
||||||
|
appImagesService *service.AppImagesService
|
||||||
emailService *service.EmailService
|
emailService *service.EmailService
|
||||||
geoLiteService *service.GeoLiteService
|
geoLiteService *service.GeoLiteService
|
||||||
auditLogService *service.AuditLogService
|
auditLogService *service.AuditLogService
|
||||||
@@ -23,10 +24,11 @@ type services struct {
|
|||||||
userGroupService *service.UserGroupService
|
userGroupService *service.UserGroupService
|
||||||
ldapService *service.LdapService
|
ldapService *service.LdapService
|
||||||
apiKeyService *service.ApiKeyService
|
apiKeyService *service.ApiKeyService
|
||||||
|
versionService *service.VersionService
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initializes all services
|
// Initializes all services
|
||||||
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (svc *services, err error) {
|
func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client, imageExtensions map[string]string) (svc *services, err error) {
|
||||||
svc = &services{}
|
svc = &services{}
|
||||||
|
|
||||||
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
svc.appConfigService, err = service.NewAppConfigService(ctx, db)
|
||||||
@@ -34,6 +36,8 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
|||||||
return nil, fmt.Errorf("failed to create app config service: %w", err)
|
return nil, fmt.Errorf("failed to create app config service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
svc.appImagesService = service.NewAppImagesService(imageExtensions)
|
||||||
|
|
||||||
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
svc.emailService, err = service.NewEmailService(db, svc.appConfigService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create email service: %w", err)
|
return nil, fmt.Errorf("failed to create email service: %w", err)
|
||||||
@@ -52,7 +56,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
|||||||
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
return nil, fmt.Errorf("failed to create WebAuthn service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService)
|
svc.oidcService, err = service.NewOidcService(ctx, db, svc.jwtService, svc.appConfigService, svc.auditLogService, svc.customClaimService, svc.webauthnService, httpClient)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
return nil, fmt.Errorf("failed to create OIDC service: %w", err)
|
||||||
}
|
}
|
||||||
@@ -62,5 +66,7 @@ func initServices(ctx context.Context, db *gorm.DB, httpClient *http.Client) (sv
|
|||||||
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
svc.ldapService = service.NewLdapService(db, httpClient, svc.appConfigService, svc.userService, svc.userGroupService)
|
||||||
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
svc.apiKeyService = service.NewApiKeyService(db, svc.emailService)
|
||||||
|
|
||||||
|
svc.versionService = service.NewVersionService(httpClient)
|
||||||
|
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/caarlos0/env/v11"
|
"github.com/caarlos0/env/v11"
|
||||||
|
sloggin "github.com/gin-contrib/slog"
|
||||||
_ "github.com/joho/godotenv/autoload"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,19 +29,21 @@ const (
|
|||||||
DbProviderPostgres DbProvider = "postgres"
|
DbProviderPostgres DbProvider = "postgres"
|
||||||
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
MaxMindGeoLiteCityUrl string = "https://download.maxmind.com/app/geoip_download?edition_id=GeoLite2-City&license_key=%s&suffix=tar.gz"
|
||||||
defaultSqliteConnString string = "data/pocket-id.db"
|
defaultSqliteConnString string = "data/pocket-id.db"
|
||||||
|
AppUrl string = "http://localhost:1411"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EnvConfigSchema struct {
|
type EnvConfigSchema struct {
|
||||||
AppEnv string `env:"APP_ENV"`
|
AppEnv string `env:"APP_ENV" options:"toLower"`
|
||||||
AppURL string `env:"APP_URL"`
|
LogLevel string `env:"LOG_LEVEL" options:"toLower"`
|
||||||
DbProvider DbProvider `env:"DB_PROVIDER"`
|
AppURL string `env:"APP_URL" options:"toLower"`
|
||||||
|
DbProvider DbProvider `env:"DB_PROVIDER" options:"toLower"`
|
||||||
DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
|
DbConnectionString string `env:"DB_CONNECTION_STRING" options:"file"`
|
||||||
UploadPath string `env:"UPLOAD_PATH"`
|
UploadPath string `env:"UPLOAD_PATH"`
|
||||||
KeysPath string `env:"KEYS_PATH"`
|
KeysPath string `env:"KEYS_PATH"`
|
||||||
KeysStorage string `env:"KEYS_STORAGE"`
|
KeysStorage string `env:"KEYS_STORAGE"`
|
||||||
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
|
EncryptionKey []byte `env:"ENCRYPTION_KEY" options:"file"`
|
||||||
Port string `env:"PORT"`
|
Port string `env:"PORT"`
|
||||||
Host string `env:"HOST"`
|
Host string `env:"HOST" options:"toLower"`
|
||||||
UnixSocket string `env:"UNIX_SOCKET"`
|
UnixSocket string `env:"UNIX_SOCKET"`
|
||||||
UnixSocketMode string `env:"UNIX_SOCKET_MODE"`
|
UnixSocketMode string `env:"UNIX_SOCKET_MODE"`
|
||||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"`
|
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"`
|
||||||
@@ -53,6 +57,7 @@ type EnvConfigSchema struct {
|
|||||||
TrustProxy bool `env:"TRUST_PROXY"`
|
TrustProxy bool `env:"TRUST_PROXY"`
|
||||||
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
AnalyticsDisabled bool `env:"ANALYTICS_DISABLED"`
|
||||||
AllowDowngrade bool `env:"ALLOW_DOWNGRADE"`
|
AllowDowngrade bool `env:"ALLOW_DOWNGRADE"`
|
||||||
|
InternalAppURL string `env:"INTERNAL_APP_URL"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var EnvConfig = defaultConfig()
|
var EnvConfig = defaultConfig()
|
||||||
@@ -68,13 +73,14 @@ func init() {
|
|||||||
func defaultConfig() EnvConfigSchema {
|
func defaultConfig() EnvConfigSchema {
|
||||||
return EnvConfigSchema{
|
return EnvConfigSchema{
|
||||||
AppEnv: "production",
|
AppEnv: "production",
|
||||||
|
LogLevel: "info",
|
||||||
DbProvider: "sqlite",
|
DbProvider: "sqlite",
|
||||||
DbConnectionString: "",
|
DbConnectionString: "",
|
||||||
UploadPath: "data/uploads",
|
UploadPath: "data/uploads",
|
||||||
KeysPath: "data/keys",
|
KeysPath: "data/keys",
|
||||||
KeysStorage: "", // "database" or "file"
|
KeysStorage: "", // "database" or "file"
|
||||||
EncryptionKey: nil,
|
EncryptionKey: nil,
|
||||||
AppURL: "http://localhost:1411",
|
AppURL: AppUrl,
|
||||||
Port: "1411",
|
Port: "1411",
|
||||||
Host: "0.0.0.0",
|
Host: "0.0.0.0",
|
||||||
UnixSocket: "",
|
UnixSocket: "",
|
||||||
@@ -89,6 +95,7 @@ func defaultConfig() EnvConfigSchema {
|
|||||||
TrustProxy: false,
|
TrustProxy: false,
|
||||||
AnalyticsDisabled: false,
|
AnalyticsDisabled: false,
|
||||||
AllowDowngrade: false,
|
AllowDowngrade: false,
|
||||||
|
InternalAppURL: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,26 +113,40 @@ func parseEnvConfig() error {
|
|||||||
return fmt.Errorf("error parsing env config: %w", err)
|
return fmt.Errorf("error parsing env config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = resolveFileBasedEnvVariables(&EnvConfig)
|
err = prepareEnvConfig(&EnvConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error preparing env config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = validateEnvConfig(&EnvConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the environment variables
|
return nil
|
||||||
switch EnvConfig.DbProvider {
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateEnvConfig checks the EnvConfig for required fields and valid values
|
||||||
|
func validateEnvConfig(config *EnvConfigSchema) error {
|
||||||
|
if _, err := sloggin.ParseLevel(config.LogLevel); err != nil {
|
||||||
|
return errors.New("invalid LOG_LEVEL value. Must be 'debug', 'info', 'warn' or 'error'")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.DbProvider {
|
||||||
case DbProviderSqlite:
|
case DbProviderSqlite:
|
||||||
if EnvConfig.DbConnectionString == "" {
|
if config.DbConnectionString == "" {
|
||||||
EnvConfig.DbConnectionString = defaultSqliteConnString
|
config.DbConnectionString = defaultSqliteConnString
|
||||||
}
|
}
|
||||||
case DbProviderPostgres:
|
case DbProviderPostgres:
|
||||||
if EnvConfig.DbConnectionString == "" {
|
if config.DbConnectionString == "" {
|
||||||
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
return errors.New("missing required env var 'DB_CONNECTION_STRING' for Postgres database")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
return errors.New("invalid DB_PROVIDER value. Must be 'sqlite' or 'postgres'")
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedAppUrl, err := url.Parse(EnvConfig.AppURL)
|
parsedAppUrl, err := url.Parse(config.AppURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("APP_URL is not a valid URL")
|
return errors.New("APP_URL is not a valid URL")
|
||||||
}
|
}
|
||||||
@@ -133,25 +154,58 @@ func parseEnvConfig() error {
|
|||||||
return errors.New("APP_URL must not contain a path")
|
return errors.New("APP_URL must not contain a path")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch EnvConfig.KeysStorage {
|
// Derive INTERNAL_APP_URL from APP_URL if not set; validate only when provided
|
||||||
|
if config.InternalAppURL == "" {
|
||||||
|
config.InternalAppURL = config.AppURL
|
||||||
|
} else {
|
||||||
|
parsedInternalAppUrl, err := url.Parse(config.InternalAppURL)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("INTERNAL_APP_URL is not a valid URL")
|
||||||
|
}
|
||||||
|
if parsedInternalAppUrl.Path != "" {
|
||||||
|
return errors.New("INTERNAL_APP_URL must not contain a path")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch config.KeysStorage {
|
||||||
// KeysStorage defaults to "file" if empty
|
// KeysStorage defaults to "file" if empty
|
||||||
case "":
|
case "":
|
||||||
EnvConfig.KeysStorage = "file"
|
config.KeysStorage = "file"
|
||||||
case "database":
|
case "database":
|
||||||
if EnvConfig.EncryptionKey == nil {
|
if config.EncryptionKey == nil {
|
||||||
return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
|
return errors.New("ENCRYPTION_KEY must be non-empty when KEYS_STORAGE is database")
|
||||||
}
|
}
|
||||||
case "file":
|
case "file":
|
||||||
// All good, these are valid values
|
// All good, these are valid values
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", EnvConfig.KeysStorage)
|
return fmt.Errorf("invalid value for KEYS_STORAGE: %s", config.KeysStorage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate LOCAL_IPV6_RANGES
|
||||||
|
ranges := strings.Split(config.LocalIPv6Ranges, ",")
|
||||||
|
for _, rangeStr := range ranges {
|
||||||
|
rangeStr = strings.TrimSpace(rangeStr)
|
||||||
|
if rangeStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid LOCAL_IPV6_RANGES '%s': %w", rangeStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipNet.IP.To4() != nil {
|
||||||
|
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveFileBasedEnvVariables uses reflection to automatically resolve file-based secrets
|
// prepareEnvConfig processes special options for EnvConfig fields
|
||||||
func resolveFileBasedEnvVariables(config *EnvConfigSchema) error {
|
func prepareEnvConfig(config *EnvConfigSchema) error {
|
||||||
val := reflect.ValueOf(config).Elem()
|
val := reflect.ValueOf(config).Elem()
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
|
||||||
@@ -159,48 +213,65 @@ func resolveFileBasedEnvVariables(config *EnvConfigSchema) error {
|
|||||||
field := val.Field(i)
|
field := val.Field(i)
|
||||||
fieldType := typ.Field(i)
|
fieldType := typ.Field(i)
|
||||||
|
|
||||||
// Only process string and []byte fields
|
|
||||||
isString := field.Kind() == reflect.String
|
|
||||||
isByteSlice := field.Kind() == reflect.Slice && field.Type().Elem().Kind() == reflect.Uint8
|
|
||||||
if !isString && !isByteSlice {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only process fields with the "options" tag set to "file"
|
|
||||||
optionsTag := fieldType.Tag.Get("options")
|
optionsTag := fieldType.Tag.Get("options")
|
||||||
if optionsTag != "file" {
|
options := strings.Split(optionsTag, ",")
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only process fields with the "env" tag
|
for _, option := range options {
|
||||||
envTag := fieldType.Tag.Get("env")
|
switch option {
|
||||||
if envTag == "" {
|
case "toLower":
|
||||||
continue
|
if field.Kind() == reflect.String {
|
||||||
}
|
field.SetString(strings.ToLower(field.String()))
|
||||||
|
}
|
||||||
envVarName := envTag
|
case "file":
|
||||||
if commaIndex := len(envTag); commaIndex > 0 {
|
err := resolveFileBasedEnvVariable(field, fieldType)
|
||||||
envVarName = envTag[:commaIndex]
|
if err != nil {
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
// If the file environment variable is not set, skip
|
}
|
||||||
envVarFileName := envVarName + "_FILE"
|
|
||||||
envVarFileValue := os.Getenv(envVarFileName)
|
|
||||||
if envVarFileValue == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
fileContent, err := os.ReadFile(envVarFileValue)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to read file for env var %s: %w", envVarFileName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if isString {
|
|
||||||
field.SetString(strings.TrimSpace(string(fileContent)))
|
|
||||||
} else {
|
|
||||||
field.SetBytes(fileContent)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveFileBasedEnvVariable checks if an environment variable with the suffix "_FILE" is set,
|
||||||
|
// reads the content of the file specified by that variable, and sets the corresponding field's value.
|
||||||
|
func resolveFileBasedEnvVariable(field reflect.Value, fieldType reflect.StructField) error {
|
||||||
|
// Only process string and []byte fields
|
||||||
|
isString := field.Kind() == reflect.String
|
||||||
|
isByteSlice := field.Kind() == reflect.Slice && field.Type().Elem().Kind() == reflect.Uint8
|
||||||
|
if !isString && !isByteSlice {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process fields with the "env" tag
|
||||||
|
envTag := fieldType.Tag.Get("env")
|
||||||
|
if envTag == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
envVarName := envTag
|
||||||
|
if commaIndex := len(envTag); commaIndex > 0 {
|
||||||
|
envVarName = envTag[:commaIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the file environment variable is not set, skip
|
||||||
|
envVarFileName := envVarName + "_FILE"
|
||||||
|
envVarFileValue := os.Getenv(envVarFileName)
|
||||||
|
if envVarFileValue == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fileContent, err := os.ReadFile(envVarFileValue)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read file for env var %s: %w", envVarFileName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isString {
|
||||||
|
field.SetString(strings.TrimSpace(string(fileContent)))
|
||||||
|
} else {
|
||||||
|
field.SetBytes(fileContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,18 +17,19 @@ func TestParseEnvConfig(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
|
t.Run("should parse valid SQLite config correctly", func(t *testing.T) {
|
||||||
EnvConfig = defaultConfig()
|
EnvConfig = defaultConfig()
|
||||||
t.Setenv("DB_PROVIDER", "sqlite")
|
t.Setenv("DB_PROVIDER", "SQLITE") // should be lowercased automatically
|
||||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
t.Setenv("APP_URL", "http://localhost:3000")
|
t.Setenv("APP_URL", "HTTP://LOCALHOST:3000")
|
||||||
|
|
||||||
err := parseEnvConfig()
|
err := parseEnvConfig()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
|
assert.Equal(t, DbProviderSqlite, EnvConfig.DbProvider)
|
||||||
|
assert.Equal(t, "http://localhost:3000", EnvConfig.AppURL)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
|
t.Run("should parse valid Postgres config correctly", func(t *testing.T) {
|
||||||
EnvConfig = defaultConfig()
|
EnvConfig = defaultConfig()
|
||||||
t.Setenv("DB_PROVIDER", "postgres")
|
t.Setenv("DB_PROVIDER", "POSTGRES")
|
||||||
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
|
t.Setenv("DB_CONNECTION_STRING", "postgres://user:pass@localhost/db")
|
||||||
t.Setenv("APP_URL", "https://example.com")
|
t.Setenv("APP_URL", "https://example.com")
|
||||||
|
|
||||||
@@ -51,7 +52,6 @@ func TestParseEnvConfig(t *testing.T) {
|
|||||||
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
|
t.Run("should set default SQLite connection string when DB_CONNECTION_STRING is empty", func(t *testing.T) {
|
||||||
EnvConfig = defaultConfig()
|
EnvConfig = defaultConfig()
|
||||||
t.Setenv("DB_PROVIDER", "sqlite")
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
t.Setenv("DB_CONNECTION_STRING", "") // Explicitly empty
|
|
||||||
t.Setenv("APP_URL", "http://localhost:3000")
|
t.Setenv("APP_URL", "http://localhost:3000")
|
||||||
|
|
||||||
err := parseEnvConfig()
|
err := parseEnvConfig()
|
||||||
@@ -91,6 +91,28 @@ func TestParseEnvConfig(t *testing.T) {
|
|||||||
assert.ErrorContains(t, err, "APP_URL must not contain a path")
|
assert.ErrorContains(t, err, "APP_URL must not contain a path")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should fail with invalid INTERNAL_APP_URL", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("INTERNAL_APP_URL", "€://not-a-valid-url")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "INTERNAL_APP_URL is not a valid URL")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("should fail when INTERNAL_APP_URL contains path", func(t *testing.T) {
|
||||||
|
EnvConfig = defaultConfig()
|
||||||
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
|
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||||
|
t.Setenv("INTERNAL_APP_URL", "http://localhost:3000/path")
|
||||||
|
|
||||||
|
err := parseEnvConfig()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorContains(t, err, "INTERNAL_APP_URL must not contain a path")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
|
t.Run("should default KEYS_STORAGE to 'file' when empty", func(t *testing.T) {
|
||||||
EnvConfig = defaultConfig()
|
EnvConfig = defaultConfig()
|
||||||
t.Setenv("DB_PROVIDER", "sqlite")
|
t.Setenv("DB_PROVIDER", "sqlite")
|
||||||
@@ -170,25 +192,25 @@ func TestParseEnvConfig(t *testing.T) {
|
|||||||
t.Setenv("DB_PROVIDER", "postgres")
|
t.Setenv("DB_PROVIDER", "postgres")
|
||||||
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
|
t.Setenv("DB_CONNECTION_STRING", "postgres://test")
|
||||||
t.Setenv("APP_URL", "https://prod.example.com")
|
t.Setenv("APP_URL", "https://prod.example.com")
|
||||||
t.Setenv("APP_ENV", "staging")
|
t.Setenv("APP_ENV", "STAGING")
|
||||||
t.Setenv("UPLOAD_PATH", "/custom/uploads")
|
t.Setenv("UPLOAD_PATH", "/custom/uploads")
|
||||||
t.Setenv("KEYS_PATH", "/custom/keys")
|
t.Setenv("KEYS_PATH", "/custom/keys")
|
||||||
t.Setenv("PORT", "8080")
|
t.Setenv("PORT", "8080")
|
||||||
t.Setenv("HOST", "127.0.0.1")
|
t.Setenv("HOST", "LOCALHOST")
|
||||||
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
|
t.Setenv("UNIX_SOCKET", "/tmp/app.sock")
|
||||||
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
|
t.Setenv("MAXMIND_LICENSE_KEY", "test-license")
|
||||||
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
|
t.Setenv("GEOLITE_DB_PATH", "/custom/geolite.mmdb")
|
||||||
|
|
||||||
err := parseEnvConfig()
|
err := parseEnvConfig()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, "staging", EnvConfig.AppEnv)
|
assert.Equal(t, "staging", EnvConfig.AppEnv) // lowercased
|
||||||
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
|
assert.Equal(t, "/custom/uploads", EnvConfig.UploadPath)
|
||||||
assert.Equal(t, "8080", EnvConfig.Port)
|
assert.Equal(t, "8080", EnvConfig.Port)
|
||||||
assert.Equal(t, "127.0.0.1", EnvConfig.Host)
|
assert.Equal(t, "localhost", EnvConfig.Host) // lowercased
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveFileBasedEnvVariables(t *testing.T) {
|
func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {
|
||||||
// Create temporary directory for test files
|
// Create temporary directory for test files
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
@@ -203,103 +225,34 @@ func TestResolveFileBasedEnvVariables(t *testing.T) {
|
|||||||
err = os.WriteFile(dbConnFile, []byte(dbConnContent), 0600)
|
err = os.WriteFile(dbConnFile, []byte(dbConnContent), 0600)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create a binary file for testing binary data handling
|
|
||||||
binaryKeyFile := tempDir + "/binary_key.bin"
|
binaryKeyFile := tempDir + "/binary_key.bin"
|
||||||
binaryKeyContent := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10}
|
binaryKeyContent := []byte{0x01, 0x02, 0x03, 0x04}
|
||||||
err = os.WriteFile(binaryKeyFile, binaryKeyContent, 0600)
|
err = os.WriteFile(binaryKeyFile, binaryKeyContent, 0600)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("should read file content for fields with options:file tag", func(t *testing.T) {
|
t.Run("should process toLower and file options", func(t *testing.T) {
|
||||||
config := defaultConfig()
|
config := defaultConfig()
|
||||||
|
config.AppEnv = "STAGING"
|
||||||
|
config.Host = "LOCALHOST"
|
||||||
|
|
||||||
// Set environment variables pointing to files
|
|
||||||
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
|
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
|
||||||
t.Setenv("DB_CONNECTION_STRING_FILE", dbConnFile)
|
t.Setenv("DB_CONNECTION_STRING_FILE", dbConnFile)
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
err := prepareEnvConfig(&config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify file contents were read correctly
|
assert.Equal(t, "staging", config.AppEnv)
|
||||||
|
assert.Equal(t, "localhost", config.Host)
|
||||||
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
|
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
|
||||||
assert.Equal(t, dbConnContent, config.DbConnectionString)
|
assert.Equal(t, dbConnContent, config.DbConnectionString)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("should skip fields without options:file tag", func(t *testing.T) {
|
|
||||||
config := defaultConfig()
|
|
||||||
originalAppURL := config.AppURL
|
|
||||||
|
|
||||||
// Set a file for a field that doesn't have options:file tag
|
|
||||||
t.Setenv("APP_URL_FILE", "/tmp/nonexistent.txt")
|
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// AppURL should remain unchanged
|
|
||||||
assert.Equal(t, originalAppURL, config.AppURL)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should skip non-string fields", func(t *testing.T) {
|
|
||||||
// This test verifies that non-string fields are skipped
|
|
||||||
// We test this indirectly by ensuring the function doesn't error
|
|
||||||
// when processing the actual EnvConfigSchema which has bool fields
|
|
||||||
config := defaultConfig()
|
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should skip when _FILE environment variable is not set", func(t *testing.T) {
|
|
||||||
config := defaultConfig()
|
|
||||||
originalEncryptionKey := config.EncryptionKey
|
|
||||||
|
|
||||||
// Don't set ENCRYPTION_KEY_FILE environment variable
|
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// EncryptionKey should remain unchanged
|
|
||||||
assert.Equal(t, originalEncryptionKey, config.EncryptionKey)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should handle multiple file-based variables simultaneously", func(t *testing.T) {
|
|
||||||
config := defaultConfig()
|
|
||||||
|
|
||||||
// Set multiple file environment variables
|
|
||||||
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
|
|
||||||
t.Setenv("DB_CONNECTION_STRING_FILE", dbConnFile)
|
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// All should be resolved correctly
|
|
||||||
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
|
|
||||||
assert.Equal(t, dbConnContent, config.DbConnectionString)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should handle mixed file and non-file environment variables", func(t *testing.T) {
|
|
||||||
config := defaultConfig()
|
|
||||||
|
|
||||||
// Set both file and non-file environment variables
|
|
||||||
t.Setenv("ENCRYPTION_KEY_FILE", encryptionKeyFile)
|
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// File-based should be resolved, others should remain as set by env parser
|
|
||||||
assert.Equal(t, []byte(encryptionKeyContent), config.EncryptionKey)
|
|
||||||
assert.Equal(t, "http://localhost:1411", config.AppURL)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should handle binary data correctly", func(t *testing.T) {
|
t.Run("should handle binary data correctly", func(t *testing.T) {
|
||||||
config := defaultConfig()
|
config := defaultConfig()
|
||||||
|
|
||||||
// Set environment variable pointing to binary file
|
|
||||||
t.Setenv("ENCRYPTION_KEY_FILE", binaryKeyFile)
|
t.Setenv("ENCRYPTION_KEY_FILE", binaryKeyFile)
|
||||||
|
|
||||||
err := resolveFileBasedEnvVariables(&config)
|
err := prepareEnvConfig(&config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify binary data was read correctly without corruption
|
|
||||||
assert.Equal(t, binaryKeyContent, config.EncryptionKey)
|
assert.Equal(t, binaryKeyContent, config.EncryptionKey)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -378,3 +378,13 @@ func (e *ClientIdAlreadyExistsError) Error() string {
|
|||||||
func (e *ClientIdAlreadyExistsError) HttpStatusCode() int {
|
func (e *ClientIdAlreadyExistsError) HttpStatusCode() int {
|
||||||
return http.StatusBadRequest
|
return http.StatusBadRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserEmailNotSetError struct{}
|
||||||
|
|
||||||
|
func (e *UserEmailNotSetError) Error() string {
|
||||||
|
return "The user does not have an email address set"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *UserEmailNotSetError) HttpStatusCode() int {
|
||||||
|
return http.StatusBadRequest
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,14 +3,12 @@ package controller
|
|||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/service"
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewAppConfigController creates a new controller for application configuration endpoints
|
// NewAppConfigController creates a new controller for application configuration endpoints
|
||||||
@@ -34,13 +32,6 @@ func NewAppConfigController(
|
|||||||
group.GET("/application-configuration/all", authMiddleware.Add(), acc.listAllAppConfigHandler)
|
group.GET("/application-configuration/all", authMiddleware.Add(), acc.listAllAppConfigHandler)
|
||||||
group.PUT("/application-configuration", authMiddleware.Add(), acc.updateAppConfigHandler)
|
group.PUT("/application-configuration", authMiddleware.Add(), acc.updateAppConfigHandler)
|
||||||
|
|
||||||
group.GET("/application-configuration/logo", acc.getLogoHandler)
|
|
||||||
group.GET("/application-configuration/background-image", acc.getBackgroundImageHandler)
|
|
||||||
group.GET("/application-configuration/favicon", acc.getFaviconHandler)
|
|
||||||
group.PUT("/application-configuration/logo", authMiddleware.Add(), acc.updateLogoHandler)
|
|
||||||
group.PUT("/application-configuration/favicon", authMiddleware.Add(), acc.updateFaviconHandler)
|
|
||||||
group.PUT("/application-configuration/background-image", authMiddleware.Add(), acc.updateBackgroundImageHandler)
|
|
||||||
|
|
||||||
group.POST("/application-configuration/test-email", authMiddleware.Add(), acc.testEmailHandler)
|
group.POST("/application-configuration/test-email", authMiddleware.Add(), acc.testEmailHandler)
|
||||||
group.POST("/application-configuration/sync-ldap", authMiddleware.Add(), acc.syncLdapHandler)
|
group.POST("/application-configuration/sync-ldap", authMiddleware.Add(), acc.syncLdapHandler)
|
||||||
}
|
}
|
||||||
@@ -129,147 +120,6 @@ func (acc *AppConfigController) updateAppConfigHandler(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, configVariablesDto)
|
c.JSON(http.StatusOK, configVariablesDto)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getLogoHandler godoc
|
|
||||||
// @Summary Get logo image
|
|
||||||
// @Description Get the logo image for the application
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
|
|
||||||
// @Produce image/png
|
|
||||||
// @Produce image/jpeg
|
|
||||||
// @Produce image/svg+xml
|
|
||||||
// @Success 200 {file} binary "Logo image"
|
|
||||||
// @Router /api/application-configuration/logo [get]
|
|
||||||
func (acc *AppConfigController) getLogoHandler(c *gin.Context) {
|
|
||||||
dbConfig := acc.appConfigService.GetDbConfig()
|
|
||||||
|
|
||||||
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
|
|
||||||
|
|
||||||
var imageName, imageType string
|
|
||||||
if lightLogo {
|
|
||||||
imageName = "logoLight"
|
|
||||||
imageType = dbConfig.LogoLightImageType.Value
|
|
||||||
} else {
|
|
||||||
imageName = "logoDark"
|
|
||||||
imageType = dbConfig.LogoDarkImageType.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
acc.getImage(c, imageName, imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFaviconHandler godoc
|
|
||||||
// @Summary Get favicon
|
|
||||||
// @Description Get the favicon for the application
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Produce image/x-icon
|
|
||||||
// @Success 200 {file} binary "Favicon image"
|
|
||||||
// @Router /api/application-configuration/favicon [get]
|
|
||||||
func (acc *AppConfigController) getFaviconHandler(c *gin.Context) {
|
|
||||||
acc.getImage(c, "favicon", "ico")
|
|
||||||
}
|
|
||||||
|
|
||||||
// getBackgroundImageHandler godoc
|
|
||||||
// @Summary Get background image
|
|
||||||
// @Description Get the background image for the application
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Produce image/png
|
|
||||||
// @Produce image/jpeg
|
|
||||||
// @Success 200 {file} binary "Background image"
|
|
||||||
// @Router /api/application-configuration/background-image [get]
|
|
||||||
func (acc *AppConfigController) getBackgroundImageHandler(c *gin.Context) {
|
|
||||||
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
|
|
||||||
acc.getImage(c, "background", imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateLogoHandler godoc
|
|
||||||
// @Summary Update logo
|
|
||||||
// @Description Update the application logo
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Accept multipart/form-data
|
|
||||||
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
|
|
||||||
// @Param file formData file true "Logo image file"
|
|
||||||
// @Success 204 "No Content"
|
|
||||||
// @Router /api/application-configuration/logo [put]
|
|
||||||
func (acc *AppConfigController) updateLogoHandler(c *gin.Context) {
|
|
||||||
dbConfig := acc.appConfigService.GetDbConfig()
|
|
||||||
|
|
||||||
lightLogo, _ := strconv.ParseBool(c.DefaultQuery("light", "true"))
|
|
||||||
|
|
||||||
var imageName, imageType string
|
|
||||||
if lightLogo {
|
|
||||||
imageName = "logoLight"
|
|
||||||
imageType = dbConfig.LogoLightImageType.Value
|
|
||||||
} else {
|
|
||||||
imageName = "logoDark"
|
|
||||||
imageType = dbConfig.LogoDarkImageType.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
acc.updateImage(c, imageName, imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateFaviconHandler godoc
|
|
||||||
// @Summary Update favicon
|
|
||||||
// @Description Update the application favicon
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Accept multipart/form-data
|
|
||||||
// @Param file formData file true "Favicon file (.ico)"
|
|
||||||
// @Success 204 "No Content"
|
|
||||||
// @Router /api/application-configuration/favicon [put]
|
|
||||||
func (acc *AppConfigController) updateFaviconHandler(c *gin.Context) {
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fileType := utils.GetFileExtension(file.Filename)
|
|
||||||
if fileType != "ico" {
|
|
||||||
_ = c.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
acc.updateImage(c, "favicon", "ico")
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateBackgroundImageHandler godoc
|
|
||||||
// @Summary Update background image
|
|
||||||
// @Description Update the application background image
|
|
||||||
// @Tags Application Configuration
|
|
||||||
// @Accept multipart/form-data
|
|
||||||
// @Param file formData file true "Background image file"
|
|
||||||
// @Success 204 "No Content"
|
|
||||||
// @Router /api/application-configuration/background-image [put]
|
|
||||||
func (acc *AppConfigController) updateBackgroundImageHandler(c *gin.Context) {
|
|
||||||
imageType := acc.appConfigService.GetDbConfig().BackgroundImageType.Value
|
|
||||||
acc.updateImage(c, "background", imageType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getImage is a helper function to serve image files
|
|
||||||
func (acc *AppConfigController) getImage(c *gin.Context, name string, imageType string) {
|
|
||||||
imagePath := common.EnvConfig.UploadPath + "/application-images/" + name + "." + imageType
|
|
||||||
mimeType := utils.GetImageMimeType(imageType)
|
|
||||||
|
|
||||||
c.Header("Content-Type", mimeType)
|
|
||||||
|
|
||||||
utils.SetCacheControlHeader(c, 15*time.Minute, 24*time.Hour)
|
|
||||||
c.File(imagePath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateImage is a helper function to update image files
|
|
||||||
func (acc *AppConfigController) updateImage(c *gin.Context, imageName string, oldImageType string) {
|
|
||||||
file, err := c.FormFile("file")
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = acc.appConfigService.UpdateImage(c.Request.Context(), file, imageName, oldImageType)
|
|
||||||
if err != nil {
|
|
||||||
_ = c.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
}
|
|
||||||
|
|
||||||
// syncLdapHandler godoc
|
// syncLdapHandler godoc
|
||||||
// @Summary Synchronize LDAP
|
// @Summary Synchronize LDAP
|
||||||
// @Description Manually trigger LDAP synchronization
|
// @Description Manually trigger LDAP synchronization
|
||||||
|
|||||||
173
backend/internal/controller/app_images_controller.go
Normal file
173
backend/internal/controller/app_images_controller.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/middleware"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewAppImagesController(
|
||||||
|
group *gin.RouterGroup,
|
||||||
|
authMiddleware *middleware.AuthMiddleware,
|
||||||
|
appImagesService *service.AppImagesService,
|
||||||
|
) {
|
||||||
|
controller := &AppImagesController{
|
||||||
|
appImagesService: appImagesService,
|
||||||
|
}
|
||||||
|
|
||||||
|
group.GET("/application-images/logo", controller.getLogoHandler)
|
||||||
|
group.GET("/application-images/background", controller.getBackgroundImageHandler)
|
||||||
|
group.GET("/application-images/favicon", controller.getFaviconHandler)
|
||||||
|
|
||||||
|
group.PUT("/application-images/logo", authMiddleware.Add(), controller.updateLogoHandler)
|
||||||
|
group.PUT("/application-images/background", authMiddleware.Add(), controller.updateBackgroundImageHandler)
|
||||||
|
group.PUT("/application-images/favicon", authMiddleware.Add(), controller.updateFaviconHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type AppImagesController struct {
|
||||||
|
appImagesService *service.AppImagesService
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLogoHandler godoc
|
||||||
|
// @Summary Get logo image
|
||||||
|
// @Description Get the logo image for the application
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
|
||||||
|
// @Produce image/png
|
||||||
|
// @Produce image/jpeg
|
||||||
|
// @Produce image/svg+xml
|
||||||
|
// @Success 200 {file} binary "Logo image"
|
||||||
|
// @Router /api/application-images/logo [get]
|
||||||
|
func (c *AppImagesController) getLogoHandler(ctx *gin.Context) {
|
||||||
|
lightLogo, _ := strconv.ParseBool(ctx.DefaultQuery("light", "true"))
|
||||||
|
imageName := "logoLight"
|
||||||
|
if !lightLogo {
|
||||||
|
imageName = "logoDark"
|
||||||
|
}
|
||||||
|
|
||||||
|
c.getImage(ctx, imageName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getBackgroundImageHandler godoc
|
||||||
|
// @Summary Get background image
|
||||||
|
// @Description Get the background image for the application
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Produce image/png
|
||||||
|
// @Produce image/jpeg
|
||||||
|
// @Success 200 {file} binary "Background image"
|
||||||
|
// @Router /api/application-images/background [get]
|
||||||
|
func (c *AppImagesController) getBackgroundImageHandler(ctx *gin.Context) {
|
||||||
|
c.getImage(ctx, "background")
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFaviconHandler godoc
|
||||||
|
// @Summary Get favicon
|
||||||
|
// @Description Get the favicon for the application
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Produce image/x-icon
|
||||||
|
// @Success 200 {file} binary "Favicon image"
|
||||||
|
// @Router /api/application-images/favicon [get]
|
||||||
|
func (c *AppImagesController) getFaviconHandler(ctx *gin.Context) {
|
||||||
|
c.getImage(ctx, "favicon")
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLogoHandler godoc
|
||||||
|
// @Summary Update logo
|
||||||
|
// @Description Update the application logo
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Accept multipart/form-data
|
||||||
|
// @Param light query boolean false "Light mode logo (true) or dark mode logo (false)"
|
||||||
|
// @Param file formData file true "Logo image file"
|
||||||
|
// @Success 204 "No Content"
|
||||||
|
// @Router /api/application-images/logo [put]
|
||||||
|
func (c *AppImagesController) updateLogoHandler(ctx *gin.Context) {
|
||||||
|
file, err := ctx.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lightLogo, _ := strconv.ParseBool(ctx.DefaultQuery("light", "true"))
|
||||||
|
imageName := "logoLight"
|
||||||
|
if !lightLogo {
|
||||||
|
imageName = "logoDark"
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.appImagesService.UpdateImage(file, imageName); err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateBackgroundImageHandler godoc
|
||||||
|
// @Summary Update background image
|
||||||
|
// @Description Update the application background image
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Accept multipart/form-data
|
||||||
|
// @Param file formData file true "Background image file"
|
||||||
|
// @Success 204 "No Content"
|
||||||
|
// @Router /api/application-images/background [put]
|
||||||
|
func (c *AppImagesController) updateBackgroundImageHandler(ctx *gin.Context) {
|
||||||
|
file, err := ctx.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.appImagesService.UpdateImage(file, "background"); err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateFaviconHandler godoc
|
||||||
|
// @Summary Update favicon
|
||||||
|
// @Description Update the application favicon
|
||||||
|
// @Tags Application Images
|
||||||
|
// @Accept multipart/form-data
|
||||||
|
// @Param file formData file true "Favicon file (.ico)"
|
||||||
|
// @Success 204 "No Content"
|
||||||
|
// @Router /api/application-images/favicon [put]
|
||||||
|
func (c *AppImagesController) updateFaviconHandler(ctx *gin.Context) {
|
||||||
|
file, err := ctx.FormFile("file")
|
||||||
|
if err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileType := utils.GetFileExtension(file.Filename)
|
||||||
|
if fileType != "ico" {
|
||||||
|
_ = ctx.Error(&common.WrongFileTypeError{ExpectedFileType: ".ico"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.appImagesService.UpdateImage(file, "favicon"); err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AppImagesController) getImage(ctx *gin.Context, name string) {
|
||||||
|
imagePath, mimeType, err := c.appImagesService.GetImage(name)
|
||||||
|
if err != nil {
|
||||||
|
_ = ctx.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Header("Content-Type", mimeType)
|
||||||
|
utils.SetCacheControlHeader(ctx, 15*time.Minute, 24*time.Hour)
|
||||||
|
ctx.File(imagePath)
|
||||||
|
}
|
||||||
@@ -828,7 +828,7 @@ func (oc *OidcController) getClientPreviewHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, scopes)
|
preview, err := oc.oidcService.GetClientPreview(c.Request.Context(), clientID, userID, strings.Split(scopes, " "))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
40
backend/internal/controller/version_controller.go
Normal file
40
backend/internal/controller/version_controller.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/service"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewVersionController registers version-related routes.
|
||||||
|
func NewVersionController(group *gin.RouterGroup, versionService *service.VersionService) {
|
||||||
|
vc := &VersionController{versionService: versionService}
|
||||||
|
group.GET("/version/latest", vc.getLatestVersionHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
type VersionController struct {
|
||||||
|
versionService *service.VersionService
|
||||||
|
}
|
||||||
|
|
||||||
|
// getLatestVersionHandler godoc
|
||||||
|
// @Summary Get latest available version of Pocket ID
|
||||||
|
// @Tags Version
|
||||||
|
// @Produce json
|
||||||
|
// @Success 200 {object} map[string]string "Latest version information"
|
||||||
|
// @Router /api/version/latest [get]
|
||||||
|
func (vc *VersionController) getLatestVersionHandler(c *gin.Context) {
|
||||||
|
tag, err := vc.versionService.GetLatestVersion(c.Request.Context())
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
utils.SetCacheControlHeader(c, 5*time.Minute, 15*time.Minute)
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"latestVersion": tag,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -67,6 +67,9 @@ func (wkc *WellKnownController) openIDConfigurationHandler(c *gin.Context) {
|
|||||||
|
|
||||||
func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
||||||
appUrl := common.EnvConfig.AppURL
|
appUrl := common.EnvConfig.AppURL
|
||||||
|
|
||||||
|
internalAppUrl := common.EnvConfig.InternalAppURL
|
||||||
|
|
||||||
alg, err := wkc.jwtService.GetKeyAlg()
|
alg, err := wkc.jwtService.GetKeyAlg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
|
return nil, fmt.Errorf("failed to get key algorithm: %w", err)
|
||||||
@@ -74,13 +77,13 @@ func (wkc *WellKnownController) computeOIDCConfiguration() ([]byte, error) {
|
|||||||
config := map[string]any{
|
config := map[string]any{
|
||||||
"issuer": appUrl,
|
"issuer": appUrl,
|
||||||
"authorization_endpoint": appUrl + "/authorize",
|
"authorization_endpoint": appUrl + "/authorize",
|
||||||
"token_endpoint": appUrl + "/api/oidc/token",
|
"token_endpoint": internalAppUrl + "/api/oidc/token",
|
||||||
"userinfo_endpoint": appUrl + "/api/oidc/userinfo",
|
"userinfo_endpoint": internalAppUrl + "/api/oidc/userinfo",
|
||||||
"end_session_endpoint": appUrl + "/api/oidc/end-session",
|
"end_session_endpoint": appUrl + "/api/oidc/end-session",
|
||||||
"introspection_endpoint": appUrl + "/api/oidc/introspect",
|
"introspection_endpoint": internalAppUrl + "/api/oidc/introspect",
|
||||||
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
"device_authorization_endpoint": appUrl + "/api/oidc/device/authorize",
|
||||||
"jwks_uri": appUrl + "/.well-known/jwks.json",
|
"jwks_uri": internalAppUrl + "/.well-known/jwks.json",
|
||||||
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode},
|
"grant_types_supported": []string{service.GrantTypeAuthorizationCode, service.GrantTypeRefreshToken, service.GrantTypeDeviceCode, service.GrantTypeClientCredentials},
|
||||||
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
"scopes_supported": []string{"openid", "profile", "email", "groups"},
|
||||||
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
"claims_supported": []string{"sub", "given_name", "family_name", "name", "email", "email_verified", "preferred_username", "picture", "groups"},
|
||||||
"response_types_supported": []string{"code", "id_token"},
|
"response_types_supported": []string{"code", "id_token"},
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type AppConfigUpdateDto struct {
|
|||||||
SignupDefaultUserGroupIDs string `json:"signupDefaultUserGroupIDs" binding:"omitempty,json"`
|
SignupDefaultUserGroupIDs string `json:"signupDefaultUserGroupIDs" binding:"omitempty,json"`
|
||||||
SignupDefaultCustomClaims string `json:"signupDefaultCustomClaims" binding:"omitempty,json"`
|
SignupDefaultCustomClaims string `json:"signupDefaultCustomClaims" binding:"omitempty,json"`
|
||||||
AccentColor string `json:"accentColor"`
|
AccentColor string `json:"accentColor"`
|
||||||
|
RequireUserEmail string `json:"requireUserEmail" binding:"required"`
|
||||||
SmtpHost string `json:"smtpHost"`
|
SmtpHost string `json:"smtpHost"`
|
||||||
SmtpPort string `json:"smtpPort"`
|
SmtpPort string `json:"smtpPort"`
|
||||||
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
|
SmtpFrom string `json:"smtpFrom" binding:"omitempty,email"`
|
||||||
@@ -41,6 +42,7 @@ type AppConfigUpdateDto struct {
|
|||||||
LdapAttributeUserEmail string `json:"ldapAttributeUserEmail"`
|
LdapAttributeUserEmail string `json:"ldapAttributeUserEmail"`
|
||||||
LdapAttributeUserFirstName string `json:"ldapAttributeUserFirstName"`
|
LdapAttributeUserFirstName string `json:"ldapAttributeUserFirstName"`
|
||||||
LdapAttributeUserLastName string `json:"ldapAttributeUserLastName"`
|
LdapAttributeUserLastName string `json:"ldapAttributeUserLastName"`
|
||||||
|
LdapAttributeUserDisplayName string `json:"ldapAttributeUserDisplayName"`
|
||||||
LdapAttributeUserProfilePicture string `json:"ldapAttributeUserProfilePicture"`
|
LdapAttributeUserProfilePicture string `json:"ldapAttributeUserProfilePicture"`
|
||||||
LdapAttributeGroupMember string `json:"ldapAttributeGroupMember"`
|
LdapAttributeGroupMember string `json:"ldapAttributeGroupMember"`
|
||||||
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
|
LdapAttributeGroupUniqueIdentifier string `json:"ldapAttributeGroupUniqueIdentifier"`
|
||||||
|
|||||||
@@ -31,13 +31,15 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
|||||||
|
|
||||||
type OidcClientUpdateDto struct {
|
type OidcClientUpdateDto struct {
|
||||||
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
||||||
CallbackURLs []string `json:"callbackURLs"`
|
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url"`
|
||||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url"`
|
||||||
IsPublic bool `json:"isPublic"`
|
IsPublic bool `json:"isPublic"`
|
||||||
PkceEnabled bool `json:"pkceEnabled"`
|
PkceEnabled bool `json:"pkceEnabled"`
|
||||||
RequiresReauthentication bool `json:"requiresReauthentication"`
|
RequiresReauthentication bool `json:"requiresReauthentication"`
|
||||||
Credentials OidcClientCredentialsDto `json:"credentials"`
|
Credentials OidcClientCredentialsDto `json:"credentials"`
|
||||||
LaunchURL *string `json:"launchURL" binding:"omitempty,url"`
|
LaunchURL *string `json:"launchURL" binding:"omitempty,url"`
|
||||||
|
HasLogo bool `json:"hasLogo"`
|
||||||
|
LogoURL *string `json:"logoUrl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcClientCreateDto struct {
|
type OidcClientCreateDto struct {
|
||||||
@@ -87,6 +89,7 @@ type OidcCreateTokensDto struct {
|
|||||||
RefreshToken string `form:"refresh_token"`
|
RefreshToken string `form:"refresh_token"`
|
||||||
ClientAssertion string `form:"client_assertion"`
|
ClientAssertion string `form:"client_assertion"`
|
||||||
ClientAssertionType string `form:"client_assertion_type"`
|
ClientAssertionType string `form:"client_assertion_type"`
|
||||||
|
Resource string `form:"resource"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcIntrospectDto struct {
|
type OidcIntrospectDto struct {
|
||||||
|
|||||||
@@ -1,15 +1,19 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserDto struct {
|
type UserDto struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Email string `json:"email" `
|
Email *string `json:"email" `
|
||||||
FirstName string `json:"firstName"`
|
FirstName string `json:"firstName"`
|
||||||
LastName string `json:"lastName"`
|
LastName *string `json:"lastName"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
IsAdmin bool `json:"isAdmin"`
|
IsAdmin bool `json:"isAdmin"`
|
||||||
Locale *string `json:"locale"`
|
Locale *string `json:"locale"`
|
||||||
CustomClaims []CustomClaimDto `json:"customClaims"`
|
CustomClaims []CustomClaimDto `json:"customClaims"`
|
||||||
@@ -19,14 +23,26 @@ type UserDto struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UserCreateDto struct {
|
type UserCreateDto struct {
|
||||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||||
Email string `json:"email" binding:"required,email" unorm:"nfc"`
|
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
|
||||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||||
IsAdmin bool `json:"isAdmin"`
|
DisplayName string `json:"displayName" binding:"required,min=1,max=100" unorm:"nfc"`
|
||||||
Locale *string `json:"locale"`
|
IsAdmin bool `json:"isAdmin"`
|
||||||
Disabled bool `json:"disabled"`
|
Locale *string `json:"locale"`
|
||||||
LdapID string `json:"-"`
|
Disabled bool `json:"disabled"`
|
||||||
|
LdapID string `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UserCreateDto) Validate() error {
|
||||||
|
e, ok := binding.Validator.Engine().(interface {
|
||||||
|
Struct(s any) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return errors.New("validator does not implement the expected interface")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.Struct(u)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OneTimeAccessTokenCreateDto struct {
|
type OneTimeAccessTokenCreateDto struct {
|
||||||
@@ -48,9 +64,9 @@ type UserUpdateUserGroupDto struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SignUpDto struct {
|
type SignUpDto struct {
|
||||||
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
Username string `json:"username" binding:"required,username,min=2,max=50" unorm:"nfc"`
|
||||||
Email string `json:"email" binding:"required,email" unorm:"nfc"`
|
Email *string `json:"email" binding:"omitempty,email" unorm:"nfc"`
|
||||||
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
FirstName string `json:"firstName" binding:"required,min=1,max=50" unorm:"nfc"`
|
||||||
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
LastName string `json:"lastName" binding:"max=50" unorm:"nfc"`
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
}
|
}
|
||||||
|
|||||||
105
backend/internal/dto/user_dto_test.go
Normal file
105
backend/internal/dto/user_dto_test.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUserCreateDto_Validate(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input UserCreateDto
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid input",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing username",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Username' failed on the 'required' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing display name",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'DisplayName' failed on the 'required' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username contains invalid characters",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "test/ser",
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Username' failed on the 'username' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid email",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: utils.Ptr("not-an-email"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "Doe",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'Email' failed on the 'email' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "first name too short",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "",
|
||||||
|
LastName: "Doe",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'FirstName' failed on the 'required' tag",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "last name too long",
|
||||||
|
input: UserCreateDto{
|
||||||
|
Username: "testuser",
|
||||||
|
Email: utils.Ptr("test@example.com"),
|
||||||
|
FirstName: "John",
|
||||||
|
LastName: "abcdfghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
|
||||||
|
DisplayName: "John Doe",
|
||||||
|
},
|
||||||
|
wantErr: "Field validation for 'LastName' failed on the 'max' tag",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.input.Validate()
|
||||||
|
|
||||||
|
if tc.wantErr == "" {
|
||||||
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorContains(t, err, tc.wantErr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin/binding"
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,6 +42,17 @@ type UserGroupCreateDto struct {
|
|||||||
LdapID string `json:"-"`
|
LdapID string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g UserGroupCreateDto) Validate() error {
|
||||||
|
e, ok := binding.Validator.Engine().(interface {
|
||||||
|
Struct(s any) error
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
return errors.New("validator does not implement the expected interface")
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.Struct(g)
|
||||||
|
}
|
||||||
|
|
||||||
type UserGroupUpdateUsersDto struct {
|
type UserGroupUpdateUsersDto struct {
|
||||||
UserIDs []string `json:"userIds" binding:"required"`
|
UserIDs []string `json:"userIds" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
@@ -10,43 +12,74 @@ import (
|
|||||||
"github.com/go-playground/validator/v10"
|
"github.com/go-playground/validator/v10"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
||||||
|
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
||||||
|
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
||||||
|
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
||||||
|
|
||||||
|
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
v := binding.Validator.Engine().(*validator.Validate)
|
v := binding.Validator.Engine().(*validator.Validate)
|
||||||
|
|
||||||
// [a-zA-Z0-9] : The username must start with an alphanumeric character
|
|
||||||
// [a-zA-Z0-9_.@-]* : The rest of the username can contain alphanumeric characters, dots, underscores, hyphens, and "@" symbols
|
|
||||||
// [a-zA-Z0-9]$ : The username must end with an alphanumeric character
|
|
||||||
var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9][a-zA-Z0-9_.@-]*[a-zA-Z0-9]$")
|
|
||||||
|
|
||||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
|
||||||
|
|
||||||
// Maximum allowed value for TTLs
|
// Maximum allowed value for TTLs
|
||||||
const maxTTL = 31 * 24 * time.Hour
|
const maxTTL = 31 * 24 * time.Hour
|
||||||
|
|
||||||
// Errors here are development-time ones
|
if err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||||
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
return ValidateUsername(fl.Field().String())
|
||||||
return validateUsernameRegex.MatchString(fl.Field().String())
|
}); err != nil {
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic("Failed to register custom validation for username: " + err.Error())
|
panic("Failed to register custom validation for username: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
if err := v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||||
return validateClientIDRegex.MatchString(fl.Field().String())
|
return ValidateClientID(fl.Field().String())
|
||||||
})
|
}); err != nil {
|
||||||
if err != nil {
|
|
||||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
panic("Failed to register custom validation for client_id: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
err = v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
if err := v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
||||||
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// Allow zero, which means the field wasn't set
|
// Allow zero, which means the field wasn't set
|
||||||
return ttl.Duration == 0 || ttl.Duration > time.Second && ttl.Duration <= maxTTL
|
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
||||||
})
|
}); err != nil {
|
||||||
if err != nil {
|
|
||||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
panic("Failed to register custom validation for ttl: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := v.RegisterValidation("callback_url", func(fl validator.FieldLevel) bool {
|
||||||
|
return ValidateCallbackURL(fl.Field().String())
|
||||||
|
}); err != nil {
|
||||||
|
panic("Failed to register custom validation for callback_url: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUsername validates username inputs
|
||||||
|
func ValidateUsername(username string) bool {
|
||||||
|
return validateUsernameRegex.MatchString(username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateClientID validates client ID inputs
|
||||||
|
func ValidateClientID(clientID string) bool {
|
||||||
|
return validateClientIDRegex.MatchString(clientID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCallbackURL validates callback URLs with support for wildcards
|
||||||
|
func ValidateCallbackURL(raw string) bool {
|
||||||
|
// Don't validate if it contains a wildcard
|
||||||
|
if strings.Contains(raw, "*") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !u.IsAbs() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
58
backend/internal/dto/validations_test.go
Normal file
58
backend/internal/dto/validations_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package dto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateUsername(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "user123", true},
|
||||||
|
{"valid with dot", "user.name", true},
|
||||||
|
{"valid with underscore", "user_name", true},
|
||||||
|
{"valid with hyphen", "user-name", true},
|
||||||
|
{"valid with at", "user@name", true},
|
||||||
|
{"starts with symbol", ".username", false},
|
||||||
|
{"ends with non-alphanumeric", "username-", false},
|
||||||
|
{"contains space", "user name", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"only special chars", "-._@", false},
|
||||||
|
{"valid long", "a1234567890_b.c-d@e", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, ValidateUsername(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateClientID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"valid simple", "client123", true},
|
||||||
|
{"valid with dot", "client.id", true},
|
||||||
|
{"valid with underscore", "client_id", true},
|
||||||
|
{"valid with hyphen", "client-id", true},
|
||||||
|
{"valid with all", "client.id-123_abc", true},
|
||||||
|
{"contains space", "client id", false},
|
||||||
|
{"contains at", "client@id", false},
|
||||||
|
{"empty", "", false},
|
||||||
|
{"only special chars", "-._", true},
|
||||||
|
{"invalid char", "client!id", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, ValidateClientID(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -37,7 +37,7 @@ func (j *ApiKeyEmailJobs) checkAndNotifyExpiringApiKeys(ctx context.Context) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range apiKeys {
|
for _, key := range apiKeys {
|
||||||
if key.User.Email == "" {
|
if key.User.Email == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key)
|
err = j.apiKeyService.SendApiKeyExpiringSoonEmail(ctx, key)
|
||||||
|
|||||||
53
backend/internal/middleware/csp_middleware.go
Normal file
53
backend/internal/middleware/csp_middleware.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CspMiddleware sets a Content Security Policy header and, when possible,
|
||||||
|
// includes a per-request nonce for inline scripts.
|
||||||
|
type CspMiddleware struct{}
|
||||||
|
|
||||||
|
func NewCspMiddleware() *CspMiddleware { return &CspMiddleware{} }
|
||||||
|
|
||||||
|
// GetCSPNonce returns the CSP nonce generated for this request, if any.
|
||||||
|
func GetCSPNonce(c *gin.Context) string {
|
||||||
|
if v, ok := c.Get("csp_nonce"); ok {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CspMiddleware) Add() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Generate a random base64 nonce for this request
|
||||||
|
nonce := generateNonce()
|
||||||
|
c.Set("csp_nonce", nonce)
|
||||||
|
|
||||||
|
csp := "default-src 'self'; " +
|
||||||
|
"base-uri 'self'; " +
|
||||||
|
"object-src 'none'; " +
|
||||||
|
"frame-ancestors 'none'; " +
|
||||||
|
"form-action 'self'; " +
|
||||||
|
"img-src * blob:;" +
|
||||||
|
"font-src 'self'; " +
|
||||||
|
"style-src 'self' 'unsafe-inline'; " +
|
||||||
|
"script-src 'self' 'nonce-" + nonce + "'"
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Security-Policy", csp)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateNonce() string {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "" // if generation fails, return empty; policy will omit nonce
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
@@ -77,7 +77,7 @@ func handleValidationError(validationErrors validator.ValidationErrors) string {
|
|||||||
case "email":
|
case "email":
|
||||||
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
|
errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
|
||||||
case "username":
|
case "username":
|
||||||
errorMessage = fmt.Sprintf("%s must only contain lowercase letters, numbers, underscores, dots, hyphens, and '@' symbols and not start or end with a special character", fieldName)
|
errorMessage = fmt.Sprintf("%s must only contain letters, numbers, underscores, dots, hyphens, and '@' symbols and not start or end with a special character", fieldName)
|
||||||
case "url":
|
case "url":
|
||||||
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
|
errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
|
||||||
case "min":
|
case "min":
|
||||||
|
|||||||
@@ -44,11 +44,9 @@ type AppConfig struct {
|
|||||||
SignupDefaultUserGroupIDs AppConfigVariable `key:"signupDefaultUserGroupIDs"`
|
SignupDefaultUserGroupIDs AppConfigVariable `key:"signupDefaultUserGroupIDs"`
|
||||||
SignupDefaultCustomClaims AppConfigVariable `key:"signupDefaultCustomClaims"`
|
SignupDefaultCustomClaims AppConfigVariable `key:"signupDefaultCustomClaims"`
|
||||||
// Internal
|
// Internal
|
||||||
BackgroundImageType AppConfigVariable `key:"backgroundImageType,internal"` // Internal
|
InstanceID AppConfigVariable `key:"instanceId,internal"` // Internal
|
||||||
LogoLightImageType AppConfigVariable `key:"logoLightImageType,internal"` // Internal
|
|
||||||
LogoDarkImageType AppConfigVariable `key:"logoDarkImageType,internal"` // Internal
|
|
||||||
InstanceID AppConfigVariable `key:"instanceId,internal"` // Internal
|
|
||||||
// Email
|
// Email
|
||||||
|
RequireUserEmail AppConfigVariable `key:"requireUserEmail,public"` // Public
|
||||||
SmtpHost AppConfigVariable `key:"smtpHost"`
|
SmtpHost AppConfigVariable `key:"smtpHost"`
|
||||||
SmtpPort AppConfigVariable `key:"smtpPort"`
|
SmtpPort AppConfigVariable `key:"smtpPort"`
|
||||||
SmtpFrom AppConfigVariable `key:"smtpFrom"`
|
SmtpFrom AppConfigVariable `key:"smtpFrom"`
|
||||||
@@ -74,6 +72,7 @@ type AppConfig struct {
|
|||||||
LdapAttributeUserEmail AppConfigVariable `key:"ldapAttributeUserEmail"`
|
LdapAttributeUserEmail AppConfigVariable `key:"ldapAttributeUserEmail"`
|
||||||
LdapAttributeUserFirstName AppConfigVariable `key:"ldapAttributeUserFirstName"`
|
LdapAttributeUserFirstName AppConfigVariable `key:"ldapAttributeUserFirstName"`
|
||||||
LdapAttributeUserLastName AppConfigVariable `key:"ldapAttributeUserLastName"`
|
LdapAttributeUserLastName AppConfigVariable `key:"ldapAttributeUserLastName"`
|
||||||
|
LdapAttributeUserDisplayName AppConfigVariable `key:"ldapAttributeUserDisplayName"`
|
||||||
LdapAttributeUserProfilePicture AppConfigVariable `key:"ldapAttributeUserProfilePicture"`
|
LdapAttributeUserProfilePicture AppConfigVariable `key:"ldapAttributeUserProfilePicture"`
|
||||||
LdapAttributeGroupMember AppConfigVariable `key:"ldapAttributeGroupMember"`
|
LdapAttributeGroupMember AppConfigVariable `key:"ldapAttributeGroupMember"`
|
||||||
LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
|
LdapAttributeGroupUniqueIdentifier AppConfigVariable `key:"ldapAttributeGroupUniqueIdentifier"`
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"gorm.io/gorm"
|
|
||||||
|
|
||||||
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
|
||||||
)
|
)
|
||||||
@@ -21,6 +20,14 @@ type UserAuthorizedOidcClient struct {
|
|||||||
Client OidcClient
|
Client OidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c UserAuthorizedOidcClient) Scopes() []string {
|
||||||
|
if len(c.Scope) == 0 {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Split(c.Scope, " ")
|
||||||
|
}
|
||||||
|
|
||||||
type OidcAuthorizationCode struct {
|
type OidcAuthorizationCode struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
@@ -45,7 +52,6 @@ type OidcClient struct {
|
|||||||
CallbackURLs UrlList
|
CallbackURLs UrlList
|
||||||
LogoutCallbackURLs UrlList
|
LogoutCallbackURLs UrlList
|
||||||
ImageType *string
|
ImageType *string
|
||||||
HasLogo bool `gorm:"-"`
|
|
||||||
IsPublic bool
|
IsPublic bool
|
||||||
PkceEnabled bool
|
PkceEnabled bool
|
||||||
RequiresReauthentication bool
|
RequiresReauthentication bool
|
||||||
@@ -58,6 +64,10 @@ type OidcClient struct {
|
|||||||
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
|
UserAuthorizedOidcClients []UserAuthorizedOidcClient `gorm:"foreignKey:ClientID;references:ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c OidcClient) HasLogo() bool {
|
||||||
|
return c.ImageType != nil && *c.ImageType != ""
|
||||||
|
}
|
||||||
|
|
||||||
type OidcRefreshToken struct {
|
type OidcRefreshToken struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
@@ -72,10 +82,12 @@ type OidcRefreshToken struct {
|
|||||||
Client OidcClient
|
Client OidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *OidcClient) AfterFind(_ *gorm.DB) (err error) {
|
func (c OidcRefreshToken) Scopes() []string {
|
||||||
// Compute HasLogo field
|
if len(c.Scope) == 0 {
|
||||||
c.HasLogo = c.ImageType != nil && *c.ImageType != ""
|
return []string{}
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
return strings.Split(c.Scope, " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
type OidcClientCredentials struct { //nolint:recvcheck
|
type OidcClientCredentials struct { //nolint:recvcheck
|
||||||
|
|||||||
@@ -13,14 +13,15 @@ import (
|
|||||||
type User struct {
|
type User struct {
|
||||||
Base
|
Base
|
||||||
|
|
||||||
Username string `sortable:"true"`
|
Username string `sortable:"true"`
|
||||||
Email string `sortable:"true"`
|
Email *string `sortable:"true"`
|
||||||
FirstName string `sortable:"true"`
|
FirstName string `sortable:"true"`
|
||||||
LastName string `sortable:"true"`
|
LastName string `sortable:"true"`
|
||||||
IsAdmin bool `sortable:"true"`
|
DisplayName string `sortable:"true"`
|
||||||
Locale *string
|
IsAdmin bool `sortable:"true"`
|
||||||
LdapID *string
|
Locale *string
|
||||||
Disabled bool `sortable:"true"`
|
LdapID *string
|
||||||
|
Disabled bool `sortable:"true"`
|
||||||
|
|
||||||
CustomClaims []CustomClaim
|
CustomClaims []CustomClaim
|
||||||
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
UserGroups []UserGroup `gorm:"many2many:user_groups_users;"`
|
||||||
@@ -31,7 +32,12 @@ func (u User) WebAuthnID() []byte { return []byte(u.ID) }
|
|||||||
|
|
||||||
func (u User) WebAuthnName() string { return u.Username }
|
func (u User) WebAuthnName() string { return u.Username }
|
||||||
|
|
||||||
func (u User) WebAuthnDisplayName() string { return u.FirstName + " " + u.LastName }
|
func (u User) WebAuthnDisplayName() string {
|
||||||
|
if u.DisplayName != "" {
|
||||||
|
return u.DisplayName
|
||||||
|
}
|
||||||
|
return u.FirstName + " " + u.LastName
|
||||||
|
}
|
||||||
|
|
||||||
func (u User) WebAuthnIcon() string { return "" }
|
func (u User) WebAuthnIcon() string { return "" }
|
||||||
|
|
||||||
@@ -66,7 +72,9 @@ func (u User) WebAuthnCredentialDescriptors() (descriptors []protocol.Credential
|
|||||||
return descriptors
|
return descriptors
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u User) FullName() string { return u.FirstName + " " + u.LastName }
|
func (u User) FullName() string {
|
||||||
|
return u.FirstName + " " + u.LastName
|
||||||
|
}
|
||||||
|
|
||||||
func (u User) Initials() string {
|
func (u User) Initials() string {
|
||||||
first := utils.GetFirstCharacter(u.FirstName)
|
first := utils.GetFirstCharacter(u.FirstName)
|
||||||
|
|||||||
@@ -144,9 +144,13 @@ func (s *ApiKeyService) SendApiKeyExpiringSoonEmail(ctx context.Context, apiKey
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.Email == nil {
|
||||||
|
return &common.UserEmailNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
err := SendEmail(ctx, s.emailService, email.Address{
|
err := SendEmail(ctx, s.emailService, email.Address{
|
||||||
Name: user.FullName(),
|
Name: user.FullName(),
|
||||||
Email: user.Email,
|
Email: *user.Email,
|
||||||
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
|
}, ApiKeyExpiringSoonTemplate, &ApiKeyExpiringSoonTemplateData{
|
||||||
ApiKeyName: apiKey.Name,
|
ApiKeyName: apiKey.Name,
|
||||||
ExpiresAt: apiKey.ExpiresAt.ToTime(),
|
ExpiresAt: apiKey.ExpiresAt.ToTime(),
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"mime/multipart"
|
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -70,11 +69,9 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
|
|||||||
SignupDefaultCustomClaims: model.AppConfigVariable{Value: "[]"},
|
SignupDefaultCustomClaims: model.AppConfigVariable{Value: "[]"},
|
||||||
AccentColor: model.AppConfigVariable{Value: "default"},
|
AccentColor: model.AppConfigVariable{Value: "default"},
|
||||||
// Internal
|
// Internal
|
||||||
BackgroundImageType: model.AppConfigVariable{Value: "jpg"},
|
InstanceID: model.AppConfigVariable{Value: ""},
|
||||||
LogoLightImageType: model.AppConfigVariable{Value: "svg"},
|
|
||||||
LogoDarkImageType: model.AppConfigVariable{Value: "svg"},
|
|
||||||
InstanceID: model.AppConfigVariable{Value: ""},
|
|
||||||
// Email
|
// Email
|
||||||
|
RequireUserEmail: model.AppConfigVariable{Value: "true"},
|
||||||
SmtpHost: model.AppConfigVariable{},
|
SmtpHost: model.AppConfigVariable{},
|
||||||
SmtpPort: model.AppConfigVariable{},
|
SmtpPort: model.AppConfigVariable{},
|
||||||
SmtpFrom: model.AppConfigVariable{},
|
SmtpFrom: model.AppConfigVariable{},
|
||||||
@@ -100,6 +97,7 @@ func (s *AppConfigService) getDefaultDbConfig() *model.AppConfig {
|
|||||||
LdapAttributeUserEmail: model.AppConfigVariable{},
|
LdapAttributeUserEmail: model.AppConfigVariable{},
|
||||||
LdapAttributeUserFirstName: model.AppConfigVariable{},
|
LdapAttributeUserFirstName: model.AppConfigVariable{},
|
||||||
LdapAttributeUserLastName: model.AppConfigVariable{},
|
LdapAttributeUserLastName: model.AppConfigVariable{},
|
||||||
|
LdapAttributeUserDisplayName: model.AppConfigVariable{Value: "cn"},
|
||||||
LdapAttributeUserProfilePicture: model.AppConfigVariable{},
|
LdapAttributeUserProfilePicture: model.AppConfigVariable{},
|
||||||
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
|
LdapAttributeGroupMember: model.AppConfigVariable{Value: "member"},
|
||||||
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
|
LdapAttributeGroupUniqueIdentifier: model.AppConfigVariable{},
|
||||||
@@ -321,39 +319,6 @@ func (s *AppConfigService) ListAppConfig(showAll bool) []model.AppConfigVariable
|
|||||||
return s.GetDbConfig().ToAppConfigVariableSlice(showAll, true)
|
return s.GetDbConfig().ToAppConfigVariableSlice(showAll, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *AppConfigService) UpdateImage(ctx context.Context, uploadedFile *multipart.FileHeader, imageName string, oldImageType string) (err error) {
|
|
||||||
fileType := strings.ToLower(utils.GetFileExtension(uploadedFile.Filename))
|
|
||||||
mimeType := utils.GetImageMimeType(fileType)
|
|
||||||
if mimeType == "" {
|
|
||||||
return &common.FileTypeNotSupportedError{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save the updated image
|
|
||||||
imagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + fileType
|
|
||||||
err = utils.SaveFile(uploadedFile, imagePath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete the old image if it has a different file type, then update the type in the database
|
|
||||||
if fileType != oldImageType {
|
|
||||||
oldImagePath := common.EnvConfig.UploadPath + "/application-images/" + imageName + "." + oldImageType
|
|
||||||
err = os.Remove(oldImagePath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the file type in the database
|
|
||||||
err = s.UpdateAppConfigValues(ctx, imageName+"ImageType", fileType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadDbConfig loads the configuration values from the database into the DbConfig struct.
|
// LoadDbConfig loads the configuration values from the database into the DbConfig struct.
|
||||||
func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
|
func (s *AppConfigService) LoadDbConfig(ctx context.Context) (err error) {
|
||||||
dest, err := s.loadDbConfigInternal(ctx, s.db)
|
dest, err := s.loadDbConfigInternal(ctx, s.db)
|
||||||
|
|||||||
82
backend/internal/service/app_images_service.go
Normal file
82
backend/internal/service/app_images_service.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"mime/multipart"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AppImagesService struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
extensions map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAppImagesService(extensions map[string]string) *AppImagesService {
|
||||||
|
return &AppImagesService{extensions: extensions}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppImagesService) GetImage(name string) (string, string, error) {
|
||||||
|
ext, err := s.getExtension(name)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
mimeType := utils.GetImageMimeType(ext)
|
||||||
|
if mimeType == "" {
|
||||||
|
return "", "", fmt.Errorf("unsupported image type '%s'", ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", name, ext))
|
||||||
|
return imagePath, mimeType, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppImagesService) UpdateImage(file *multipart.FileHeader, imageName string) error {
|
||||||
|
fileType := strings.ToLower(utils.GetFileExtension(file.Filename))
|
||||||
|
mimeType := utils.GetImageMimeType(fileType)
|
||||||
|
if mimeType == "" {
|
||||||
|
return &common.FileTypeNotSupportedError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
currentExt, ok := s.extensions[imageName]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unknown application image '%s'", imageName)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, fileType))
|
||||||
|
|
||||||
|
if err := utils.SaveFile(file, imagePath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentExt != "" && currentExt != fileType {
|
||||||
|
oldImagePath := filepath.Join(common.EnvConfig.UploadPath, "application-images", fmt.Sprintf("%s.%s", imageName, currentExt))
|
||||||
|
if err := os.Remove(oldImagePath); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.extensions[imageName] = fileType
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *AppImagesService) getExtension(name string) (string, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
ext, ok := s.extensions[name]
|
||||||
|
if !ok || ext == "" {
|
||||||
|
return "", fmt.Errorf("unknown application image '%s'", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.ToLower(ext), nil
|
||||||
|
}
|
||||||
88
backend/internal/service/app_images_service_test.go
Normal file
88
backend/internal/service/app_images_service_test.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io/fs"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppImagesService_GetImage(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
originalUploadPath := common.EnvConfig.UploadPath
|
||||||
|
common.EnvConfig.UploadPath = tempDir
|
||||||
|
t.Cleanup(func() {
|
||||||
|
common.EnvConfig.UploadPath = originalUploadPath
|
||||||
|
})
|
||||||
|
|
||||||
|
imagesDir := filepath.Join(tempDir, "application-images")
|
||||||
|
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
|
||||||
|
|
||||||
|
filePath := filepath.Join(imagesDir, "background.webp")
|
||||||
|
require.NoError(t, os.WriteFile(filePath, []byte("data"), fs.FileMode(0o644)))
|
||||||
|
|
||||||
|
service := NewAppImagesService(map[string]string{"background": "webp"})
|
||||||
|
|
||||||
|
path, mimeType, err := service.GetImage("background")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, filePath, path)
|
||||||
|
require.Equal(t, "image/webp", mimeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppImagesService_UpdateImage(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
originalUploadPath := common.EnvConfig.UploadPath
|
||||||
|
common.EnvConfig.UploadPath = tempDir
|
||||||
|
t.Cleanup(func() {
|
||||||
|
common.EnvConfig.UploadPath = originalUploadPath
|
||||||
|
})
|
||||||
|
|
||||||
|
imagesDir := filepath.Join(tempDir, "application-images")
|
||||||
|
require.NoError(t, os.MkdirAll(imagesDir, 0o755))
|
||||||
|
|
||||||
|
oldPath := filepath.Join(imagesDir, "logoLight.svg")
|
||||||
|
require.NoError(t, os.WriteFile(oldPath, []byte("old"), fs.FileMode(0o644)))
|
||||||
|
|
||||||
|
service := NewAppImagesService(map[string]string{"logoLight": "svg"})
|
||||||
|
|
||||||
|
fileHeader := newFileHeader(t, "logoLight.png", []byte("new"))
|
||||||
|
|
||||||
|
require.NoError(t, service.UpdateImage(fileHeader, "logoLight"))
|
||||||
|
|
||||||
|
_, err := os.Stat(filepath.Join(imagesDir, "logoLight.png"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = os.Stat(oldPath)
|
||||||
|
require.ErrorIs(t, err, os.ErrNotExist)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileHeader(t *testing.T, filename string, content []byte) *multipart.FileHeader {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
|
||||||
|
part, err := writer.CreateFormFile("file", filename)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = part.Write(content)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, writer.Close())
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", body)
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
_, fileHeader, err := req.FormFile("file")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return fileHeader
|
||||||
|
}
|
||||||
@@ -111,9 +111,13 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.Email == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
|
innerErr = SendEmail(innerCtx, s.emailService, email.Address{
|
||||||
Name: user.FullName(),
|
Name: user.FullName(),
|
||||||
Email: user.Email,
|
Email: *user.Email,
|
||||||
}, NewLoginTemplate, &NewLoginTemplateData{
|
}, NewLoginTemplate, &NewLoginTemplateData{
|
||||||
IPAddress: ipAddress,
|
IPAddress: ipAddress,
|
||||||
Country: createdAuditLog.Country,
|
Country: createdAuditLog.Country,
|
||||||
@@ -122,7 +126,7 @@ func (s *AuditLogService) CreateNewSignInWithEmail(ctx context.Context, ipAddres
|
|||||||
DateTime: createdAuditLog.CreatedAt.UTC(),
|
DateTime: createdAuditLog.CreatedAt.UTC(),
|
||||||
})
|
})
|
||||||
if innerErr != nil {
|
if innerErr != nil {
|
||||||
slog.ErrorContext(innerCtx, "Failed to send notification email", slog.Any("error", innerErr), slog.String("address", user.Email))
|
slog.ErrorContext(innerCtx, "Failed to send notification email", slog.Any("error", innerErr), slog.String("address", *user.Email))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ func isReservedClaim(key string) bool {
|
|||||||
"name",
|
"name",
|
||||||
"email",
|
"email",
|
||||||
"preferred_username",
|
"preferred_username",
|
||||||
|
"display_name",
|
||||||
"groups",
|
"groups",
|
||||||
TokenTypeClaim,
|
TokenTypeClaim,
|
||||||
"sub",
|
"sub",
|
||||||
|
|||||||
@@ -78,21 +78,23 @@ func (s *TestService) SeedDatabase(baseURL string) error {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e",
|
ID: "f4b89dc2-62fb-46bf-9f5f-c34f4eafe93e",
|
||||||
},
|
},
|
||||||
Username: "tim",
|
Username: "tim",
|
||||||
Email: "tim.cook@test.com",
|
Email: utils.Ptr("tim.cook@test.com"),
|
||||||
FirstName: "Tim",
|
FirstName: "Tim",
|
||||||
LastName: "Cook",
|
LastName: "Cook",
|
||||||
IsAdmin: true,
|
DisplayName: "Tim Cook",
|
||||||
|
IsAdmin: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "1cd19686-f9a6-43f4-a41f-14a0bf5b4036",
|
ID: "1cd19686-f9a6-43f4-a41f-14a0bf5b4036",
|
||||||
},
|
},
|
||||||
Username: "craig",
|
Username: "craig",
|
||||||
Email: "craig.federighi@test.com",
|
Email: utils.Ptr("craig.federighi@test.com"),
|
||||||
FirstName: "Craig",
|
FirstName: "Craig",
|
||||||
LastName: "Federighi",
|
LastName: "Federighi",
|
||||||
IsAdmin: false,
|
DisplayName: "Craig Federighi",
|
||||||
|
IsAdmin: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, user := range users {
|
for _, user := range users {
|
||||||
|
|||||||
@@ -62,9 +62,13 @@ func (srv *EmailService) SendTestEmail(ctx context.Context, recipientUserId stri
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.Email == nil {
|
||||||
|
return &common.UserEmailNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
return SendEmail(ctx, srv,
|
return SendEmail(ctx, srv,
|
||||||
email.Address{
|
email.Address{
|
||||||
Email: user.Email,
|
Email: *user.Email,
|
||||||
Name: user.FullName(),
|
Name: user.FullName(),
|
||||||
}, TestTemplate, nil)
|
}, TestTemplate, nil)
|
||||||
}
|
}
|
||||||
@@ -74,7 +78,7 @@ func SendEmail[V any](ctx context.Context, srv *EmailService, toEmail email.Addr
|
|||||||
|
|
||||||
data := &email.TemplateData[V]{
|
data := &email.TemplateData[V]{
|
||||||
AppName: dbConfig.AppName.Value,
|
AppName: dbConfig.AppName.Value,
|
||||||
LogoURL: common.EnvConfig.AppURL + "/api/application-configuration/logo",
|
LogoURL: common.EnvConfig.AppURL + "/api/application-images/logo",
|
||||||
Data: tData,
|
Data: tData,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,7 +266,7 @@ func prepareBody[V any](srv *EmailService, template email.Template[V], data *ema
|
|||||||
|
|
||||||
// prepare text part
|
// prepare text part
|
||||||
var textHeader = textproto.MIMEHeader{}
|
var textHeader = textproto.MIMEHeader{}
|
||||||
textHeader.Add("Content-Type", "text/plain;\n charset=UTF-8")
|
textHeader.Add("Content-Type", "text/plain; charset=UTF-8")
|
||||||
textHeader.Add("Content-Transfer-Encoding", "quoted-printable")
|
textHeader.Add("Content-Transfer-Encoding", "quoted-printable")
|
||||||
textPart, err := mpart.CreatePart(textHeader)
|
textPart, err := mpart.CreatePart(textHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -274,18 +278,17 @@ func prepareBody[V any](srv *EmailService, template email.Template[V], data *ema
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("execute text template: %w", err)
|
return "", "", fmt.Errorf("execute text template: %w", err)
|
||||||
}
|
}
|
||||||
|
textQp.Close()
|
||||||
|
|
||||||
// prepare html part
|
|
||||||
var htmlHeader = textproto.MIMEHeader{}
|
var htmlHeader = textproto.MIMEHeader{}
|
||||||
htmlHeader.Add("Content-Type", "text/html;\n charset=UTF-8")
|
htmlHeader.Add("Content-Type", "text/html; charset=UTF-8")
|
||||||
htmlHeader.Add("Content-Transfer-Encoding", "quoted-printable")
|
htmlHeader.Add("Content-Transfer-Encoding", "8bit")
|
||||||
htmlPart, err := mpart.CreatePart(htmlHeader)
|
htmlPart, err := mpart.CreatePart(htmlHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("create html part: %w", err)
|
return "", "", fmt.Errorf("create html part: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
htmlQp := quotedprintable.NewWriter(htmlPart)
|
err = email.GetTemplate(srv.htmlTemplates, template).ExecuteTemplate(htmlPart, "root", data)
|
||||||
err = email.GetTemplate(srv.htmlTemplates, template).ExecuteTemplate(htmlQp, "root", data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("execute html template: %w", err)
|
return "", "", fmt.Errorf("execute html template: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,35 +13,19 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oschwald/maxminddb-golang/v2"
|
"github.com/oschwald/maxminddb-golang/v2"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GeoLiteService struct {
|
type GeoLiteService struct {
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
disableUpdater bool
|
disableUpdater bool
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
localIPv6Ranges []*net.IPNet
|
|
||||||
}
|
|
||||||
|
|
||||||
var localhostIPNets = []*net.IPNet{
|
|
||||||
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
|
||||||
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
|
|
||||||
}
|
|
||||||
|
|
||||||
var privateLanIPNets = []*net.IPNet{
|
|
||||||
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
|
||||||
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
|
||||||
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
|
||||||
}
|
|
||||||
|
|
||||||
var tailscaleIPNets = []*net.IPNet{
|
|
||||||
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
|
// NewGeoLiteService initializes a new GeoLiteService instance and starts a goroutine to update the GeoLite2 City database.
|
||||||
@@ -56,67 +40,9 @@ func NewGeoLiteService(httpClient *http.Client) *GeoLiteService {
|
|||||||
service.disableUpdater = true
|
service.disableUpdater = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize IPv6 local ranges
|
|
||||||
err := service.initializeIPv6LocalRanges()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("Failed to initialize IPv6 local ranges", slog.Any("error", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return service
|
return service
|
||||||
}
|
}
|
||||||
|
|
||||||
// initializeIPv6LocalRanges parses the LOCAL_IPV6_RANGES environment variable
|
|
||||||
func (s *GeoLiteService) initializeIPv6LocalRanges() error {
|
|
||||||
rangesEnv := common.EnvConfig.LocalIPv6Ranges
|
|
||||||
if rangesEnv == "" {
|
|
||||||
return nil // No local IPv6 ranges configured
|
|
||||||
}
|
|
||||||
|
|
||||||
ranges := strings.Split(rangesEnv, ",")
|
|
||||||
localRanges := make([]*net.IPNet, 0, len(ranges))
|
|
||||||
|
|
||||||
for _, rangeStr := range ranges {
|
|
||||||
rangeStr = strings.TrimSpace(rangeStr)
|
|
||||||
if rangeStr == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(rangeStr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid IPv6 range '%s': %w", rangeStr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure it's an IPv6 range
|
|
||||||
if ipNet.IP.To4() != nil {
|
|
||||||
return fmt.Errorf("range '%s' is not a valid IPv6 range", rangeStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
localRanges = append(localRanges, ipNet)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.localIPv6Ranges = localRanges
|
|
||||||
|
|
||||||
if len(localRanges) > 0 {
|
|
||||||
slog.Info("Initialized IPv6 local ranges", slog.Int("count", len(localRanges)))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isLocalIPv6 checks if the given IPv6 address is within any of the configured local ranges
|
|
||||||
func (s *GeoLiteService) isLocalIPv6(ip net.IP) bool {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return false // Not an IPv6 address
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, localRange := range s.localIPv6Ranges {
|
|
||||||
if localRange.Contains(ip) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *GeoLiteService) DisableUpdater() bool {
|
func (s *GeoLiteService) DisableUpdater() bool {
|
||||||
return s.disableUpdater
|
return s.disableUpdater
|
||||||
}
|
}
|
||||||
@@ -129,26 +55,17 @@ func (s *GeoLiteService) GetLocationByIP(ipAddress string) (country, city string
|
|||||||
|
|
||||||
// Check the IP address against known private IP ranges
|
// Check the IP address against known private IP ranges
|
||||||
if ip := net.ParseIP(ipAddress); ip != nil {
|
if ip := net.ParseIP(ipAddress); ip != nil {
|
||||||
// Check IPv6 local ranges first
|
if utils.IsLocalIPv6(ip) {
|
||||||
if s.isLocalIPv6(ip) {
|
|
||||||
return "Internal Network", "LAN", nil
|
return "Internal Network", "LAN", nil
|
||||||
}
|
}
|
||||||
|
if utils.IsTailscaleIP(ip) {
|
||||||
// Check existing IPv4 ranges
|
return "Internal Network", "Tailscale", nil
|
||||||
for _, ipNet := range tailscaleIPNets {
|
|
||||||
if ipNet.Contains(ip) {
|
|
||||||
return "Internal Network", "Tailscale", nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for _, ipNet := range privateLanIPNets {
|
if utils.IsPrivateIP(ip) {
|
||||||
if ipNet.Contains(ip) {
|
return "Internal Network", "LAN", nil
|
||||||
return "Internal Network", "LAN", nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for _, ipNet := range localhostIPNets {
|
if utils.IsLocalhostIP(ip) {
|
||||||
if ipNet.Contains(ip) {
|
return "Internal Network", "localhost", nil
|
||||||
return "Internal Network", "localhost", nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,220 +0,0 @@
|
|||||||
package service
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGeoLiteService_IPv6LocalRanges(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
localRanges string
|
|
||||||
testIP string
|
|
||||||
expectedCountry string
|
|
||||||
expectedCity string
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv6 in local range",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
|
||||||
testIP: "2001:0db8:abcd:000::1",
|
|
||||||
expectedCountry: "Internal Network",
|
|
||||||
expectedCity: "LAN",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 not in local range",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "2001:0db8:ffff:000::1",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Multiple ranges - second range match",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
|
||||||
testIP: "2001:0db8:abcd:001::1",
|
|
||||||
expectedCountry: "Internal Network",
|
|
||||||
expectedCity: "LAN",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty local ranges",
|
|
||||||
localRanges: "",
|
|
||||||
testIP: "2001:0db8:abcd:000::1",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4 private address still works",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "192.168.1.1",
|
|
||||||
expectedCountry: "Internal Network",
|
|
||||||
expectedCity: "LAN",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 loopback",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "::1",
|
|
||||||
expectedCountry: "Internal Network",
|
|
||||||
expectedCity: "localhost",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
|
||||||
defer func() {
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
|
||||||
}()
|
|
||||||
|
|
||||||
service := NewGeoLiteService(&http.Client{})
|
|
||||||
|
|
||||||
country, city, err := service.GetLocationByIP(tt.testIP)
|
|
||||||
|
|
||||||
if tt.expectError {
|
|
||||||
if err == nil && country != "Internal Network" {
|
|
||||||
t.Errorf("Expected error or internal network classification for external IP")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectedCountry, country)
|
|
||||||
assert.Equal(t, tt.expectedCity, city)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGeoLiteService_isLocalIPv6(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
localRanges string
|
|
||||||
testIP string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid IPv6 in range",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "2001:0db8:abcd:000::1",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Valid IPv6 not in range",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "2001:0db8:ffff:000::1",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4 address should return false",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "192.168.1.1",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "No ranges configured",
|
|
||||||
localRanges: "",
|
|
||||||
testIP: "2001:0db8:abcd:000::1",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Edge of range",
|
|
||||||
localRanges: "2001:0db8:abcd:000::/56",
|
|
||||||
testIP: "2001:0db8:abcd:00ff:ffff:ffff:ffff:ffff",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = tt.localRanges
|
|
||||||
defer func() {
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
|
||||||
}()
|
|
||||||
|
|
||||||
service := NewGeoLiteService(&http.Client{})
|
|
||||||
ip := net.ParseIP(tt.testIP)
|
|
||||||
if ip == nil {
|
|
||||||
t.Fatalf("Invalid test IP: %s", tt.testIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := service.isLocalIPv6(ip)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGeoLiteService_initializeIPv6LocalRanges(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
envValue string
|
|
||||||
expectError bool
|
|
||||||
expectCount int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Valid IPv6 ranges",
|
|
||||||
envValue: "2001:0db8:abcd:000::/56,2001:0db8:abcd:001::/56",
|
|
||||||
expectError: false,
|
|
||||||
expectCount: 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty environment variable",
|
|
||||||
envValue: "",
|
|
||||||
expectError: false,
|
|
||||||
expectCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Invalid CIDR notation",
|
|
||||||
envValue: "2001:0db8:abcd:000::/999",
|
|
||||||
expectError: true,
|
|
||||||
expectCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4 range in IPv6 env var",
|
|
||||||
envValue: "192.168.1.0/24",
|
|
||||||
expectError: true,
|
|
||||||
expectCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Mixed valid and invalid ranges",
|
|
||||||
envValue: "2001:0db8:abcd:000::/56,invalid-range",
|
|
||||||
expectError: true,
|
|
||||||
expectCount: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Whitespace handling",
|
|
||||||
envValue: " 2001:0db8:abcd:000::/56 , 2001:0db8:abcd:001::/56 ",
|
|
||||||
expectError: false,
|
|
||||||
expectCount: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
originalConfig := common.EnvConfig.LocalIPv6Ranges
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = tt.envValue
|
|
||||||
defer func() {
|
|
||||||
common.EnvConfig.LocalIPv6Ranges = originalConfig
|
|
||||||
}()
|
|
||||||
|
|
||||||
service := &GeoLiteService{
|
|
||||||
httpClient: &http.Client{},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := service.initializeIPv6LocalRanges()
|
|
||||||
|
|
||||||
if tt.expectError {
|
|
||||||
require.Error(t, err)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Len(t, service.localIPv6Ranges, tt.expectCount)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/lestrrat-go/jwx/v3/jwa"
|
"github.com/lestrrat-go/jwx/v3/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
@@ -342,7 +343,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "user123",
|
ID: "user123",
|
||||||
},
|
},
|
||||||
Email: "user@example.com",
|
Email: utils.Ptr("user@example.com"),
|
||||||
IsAdmin: false,
|
IsAdmin: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -385,7 +386,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "admin123",
|
ID: "admin123",
|
||||||
},
|
},
|
||||||
Email: "admin@example.com",
|
Email: utils.Ptr("admin@example.com"),
|
||||||
IsAdmin: true,
|
IsAdmin: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -464,7 +465,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "eddsauser123",
|
ID: "eddsauser123",
|
||||||
},
|
},
|
||||||
Email: "eddsauser@example.com",
|
Email: utils.Ptr("eddsauser@example.com"),
|
||||||
IsAdmin: true,
|
IsAdmin: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -521,7 +522,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "ecdsauser123",
|
ID: "ecdsauser123",
|
||||||
},
|
},
|
||||||
Email: "ecdsauser@example.com",
|
Email: utils.Ptr("ecdsauser@example.com"),
|
||||||
IsAdmin: true,
|
IsAdmin: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -578,7 +579,7 @@ func TestGenerateVerifyAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "rsauser123",
|
ID: "rsauser123",
|
||||||
},
|
},
|
||||||
Email: "rsauser@example.com",
|
Email: utils.Ptr("rsauser@example.com"),
|
||||||
IsAdmin: true,
|
IsAdmin: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -965,7 +966,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "user123",
|
ID: "user123",
|
||||||
},
|
},
|
||||||
Email: "user@example.com",
|
Email: utils.Ptr("user@example.com"),
|
||||||
}
|
}
|
||||||
const clientID = "test-client-123"
|
const clientID = "test-client-123"
|
||||||
|
|
||||||
@@ -1092,7 +1093,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "eddsauser789",
|
ID: "eddsauser789",
|
||||||
},
|
},
|
||||||
Email: "eddsaoauth@example.com",
|
Email: utils.Ptr("eddsaoauth@example.com"),
|
||||||
}
|
}
|
||||||
const clientID = "eddsa-oauth-client"
|
const clientID = "eddsa-oauth-client"
|
||||||
|
|
||||||
@@ -1149,7 +1150,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "ecdsauser789",
|
ID: "ecdsauser789",
|
||||||
},
|
},
|
||||||
Email: "ecdsaoauth@example.com",
|
Email: utils.Ptr("ecdsaoauth@example.com"),
|
||||||
}
|
}
|
||||||
const clientID = "ecdsa-oauth-client"
|
const clientID = "ecdsa-oauth-client"
|
||||||
|
|
||||||
@@ -1206,7 +1207,7 @@ func TestGenerateVerifyOAuthAccessToken(t *testing.T) {
|
|||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: "rsauser789",
|
ID: "rsauser789",
|
||||||
},
|
},
|
||||||
Email: "rsaoauth@example.com",
|
Email: utils.Ptr("rsaoauth@example.com"),
|
||||||
}
|
}
|
||||||
const clientID = "rsa-oauth-client"
|
const clientID = "rsa-oauth-client"
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/go-ldap/ldap/v3"
|
"github.com/go-ldap/ldap/v3"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
"golang.org/x/text/unicode/norm"
|
"golang.org/x/text/unicode/norm"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
@@ -179,10 +180,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
username = norm.NFC.String(username)
|
||||||
|
|
||||||
var databaseUser model.User
|
var databaseUser model.User
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Where("username = ? AND ldap_id IS NOT NULL", norm.NFC.String(username)).
|
Where("username = ? AND ldap_id IS NOT NULL", username).
|
||||||
First(&databaseUser).
|
First(&databaseUser).
|
||||||
Error
|
Error
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
@@ -202,6 +205,12 @@ func (s *LdapService) SyncGroups(ctx context.Context, tx *gorm.DB, client *ldap.
|
|||||||
}
|
}
|
||||||
dto.Normalize(syncGroup)
|
dto.Normalize(syncGroup)
|
||||||
|
|
||||||
|
err = syncGroup.Validate()
|
||||||
|
if err != nil {
|
||||||
|
slog.WarnContext(ctx, "LDAP user group object is not valid", slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if databaseGroup.ID == "" {
|
if databaseGroup.ID == "" {
|
||||||
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
newGroup, err := s.groupService.createInternal(ctx, syncGroup, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -270,6 +279,7 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
dbConfig.LdapAttributeUserFirstName.Value,
|
dbConfig.LdapAttributeUserFirstName.Value,
|
||||||
dbConfig.LdapAttributeUserLastName.Value,
|
dbConfig.LdapAttributeUserLastName.Value,
|
||||||
dbConfig.LdapAttributeUserProfilePicture.Value,
|
dbConfig.LdapAttributeUserProfilePicture.Value,
|
||||||
|
dbConfig.LdapAttributeUserDisplayName.Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filters must start and finish with ()!
|
// Filters must start and finish with ()!
|
||||||
@@ -338,15 +348,27 @@ func (s *LdapService) SyncUsers(ctx context.Context, tx *gorm.DB, client *ldap.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
newUser := dto.UserCreateDto{
|
newUser := dto.UserCreateDto{
|
||||||
Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
|
Username: value.GetAttributeValue(dbConfig.LdapAttributeUserUsername.Value),
|
||||||
Email: value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value),
|
Email: utils.PtrOrNil(value.GetAttributeValue(dbConfig.LdapAttributeUserEmail.Value)),
|
||||||
FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
|
FirstName: value.GetAttributeValue(dbConfig.LdapAttributeUserFirstName.Value),
|
||||||
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
LastName: value.GetAttributeValue(dbConfig.LdapAttributeUserLastName.Value),
|
||||||
IsAdmin: isAdmin,
|
DisplayName: value.GetAttributeValue(dbConfig.LdapAttributeUserDisplayName.Value),
|
||||||
LdapID: ldapId,
|
IsAdmin: isAdmin,
|
||||||
|
LdapID: ldapId,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if newUser.DisplayName == "" {
|
||||||
|
newUser.DisplayName = strings.TrimSpace(newUser.FirstName + " " + newUser.LastName)
|
||||||
|
}
|
||||||
|
|
||||||
dto.Normalize(newUser)
|
dto.Normalize(newUser)
|
||||||
|
|
||||||
|
err = newUser.Validate()
|
||||||
|
if err != nil {
|
||||||
|
slog.WarnContext(ctx, "LDAP user object is not valid", slog.Any("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if databaseUser.ID == "" {
|
if databaseUser.ID == "" {
|
||||||
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
_, err = s.userService.createUserInternal(ctx, newUser, true, tx)
|
||||||
if errors.Is(err, &common.AlreadyInUseError{}) {
|
if errors.Is(err, &common.AlreadyInUseError{}) {
|
||||||
|
|||||||
@@ -8,10 +8,14 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"regexp"
|
"regexp"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -37,9 +41,11 @@ const (
|
|||||||
GrantTypeAuthorizationCode = "authorization_code"
|
GrantTypeAuthorizationCode = "authorization_code"
|
||||||
GrantTypeRefreshToken = "refresh_token"
|
GrantTypeRefreshToken = "refresh_token"
|
||||||
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
GrantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
GrantTypeClientCredentials = "client_credentials"
|
||||||
|
|
||||||
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
ClientAssertionTypeJWTBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" //nolint:gosec
|
||||||
|
|
||||||
|
AccessTokenDuration = time.Hour
|
||||||
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
|
RefreshTokenDuration = 30 * 24 * time.Hour // 30 days
|
||||||
DeviceCodeDuration = 15 * time.Minute
|
DeviceCodeDuration = 15 * time.Minute
|
||||||
)
|
)
|
||||||
@@ -64,6 +70,7 @@ func NewOidcService(
|
|||||||
auditLogService *AuditLogService,
|
auditLogService *AuditLogService,
|
||||||
customClaimService *CustomClaimService,
|
customClaimService *CustomClaimService,
|
||||||
webAuthnService *WebAuthnService,
|
webAuthnService *WebAuthnService,
|
||||||
|
httpClient *http.Client,
|
||||||
) (s *OidcService, err error) {
|
) (s *OidcService, err error) {
|
||||||
s = &OidcService{
|
s = &OidcService{
|
||||||
db: db,
|
db: db,
|
||||||
@@ -72,6 +79,7 @@ func NewOidcService(
|
|||||||
auditLogService: auditLogService,
|
auditLogService: auditLogService,
|
||||||
customClaimService: customClaimService,
|
customClaimService: customClaimService,
|
||||||
webAuthnService: webAuthnService,
|
webAuthnService: webAuthnService,
|
||||||
|
httpClient: httpClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
// Note: we don't pass the HTTP Client with OTel instrumented to this because requests are always made in background and not tied to a specific trace
|
||||||
@@ -247,6 +255,8 @@ func (s *OidcService) CreateTokens(ctx context.Context, input dto.OidcCreateToke
|
|||||||
return s.createTokenFromRefreshToken(ctx, input)
|
return s.createTokenFromRefreshToken(ctx, input)
|
||||||
case GrantTypeDeviceCode:
|
case GrantTypeDeviceCode:
|
||||||
return s.createTokenFromDeviceCode(ctx, input)
|
return s.createTokenFromDeviceCode(ctx, input)
|
||||||
|
case GrantTypeClientCredentials:
|
||||||
|
return s.createTokenFromClientCredentials(ctx, input)
|
||||||
default:
|
default:
|
||||||
return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
|
return CreatedTokens{}, &common.OidcGrantTypeNotSupportedError{}
|
||||||
}
|
}
|
||||||
@@ -329,7 +339,35 @@ func (s *OidcService) createTokenFromDeviceCode(ctx context.Context, input dto.O
|
|||||||
IdToken: idToken,
|
IdToken: idToken,
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
ExpiresIn: time.Hour,
|
ExpiresIn: AccessTokenDuration,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) createTokenFromClientCredentials(ctx context.Context, input dto.OidcCreateTokensDto) (CreatedTokens, error) {
|
||||||
|
client, err := s.verifyClientCredentialsInternal(ctx, s.db, clientAuthCredentialsFromCreateTokensDto(&input), false)
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateOAuthAccessToken uses user.ID as a "sub" claim. Prefix is used to take those security considerations
|
||||||
|
// into account: https://datatracker.ietf.org/doc/html/rfc9068#name-security-considerations
|
||||||
|
dummyUser := model.User{
|
||||||
|
Base: model.Base{ID: "client-" + client.ID},
|
||||||
|
}
|
||||||
|
|
||||||
|
audClaim := client.ID
|
||||||
|
if input.Resource != "" {
|
||||||
|
audClaim = input.Resource
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := s.jwtService.GenerateOAuthAccessToken(dummyUser, audClaim)
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return CreatedTokens{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
ExpiresIn: AccessTokenDuration,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,7 +441,7 @@ func (s *OidcService) createTokenFromAuthorizationCode(ctx context.Context, inpu
|
|||||||
IdToken: idToken,
|
IdToken: idToken,
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
ExpiresIn: time.Hour,
|
ExpiresIn: AccessTokenDuration,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -437,7 +475,7 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
var storedRefreshToken model.OidcRefreshToken
|
var storedRefreshToken model.OidcRefreshToken
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Preload("User").
|
Preload("User.UserGroups").
|
||||||
Where(
|
Where(
|
||||||
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
"token = ? AND expires_at > ? AND user_id = ? AND client_id = ?",
|
||||||
utils.CreateSha256Hash(rt),
|
utils.CreateSha256Hash(rt),
|
||||||
@@ -447,10 +485,9 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
).
|
).
|
||||||
First(&storedRefreshToken).
|
First(&storedRefreshToken).
|
||||||
Error
|
Error
|
||||||
if err != nil {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
||||||
return CreatedTokens{}, &common.OidcInvalidRefreshTokenError{}
|
} else if err != nil {
|
||||||
}
|
|
||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -465,6 +502,19 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
return CreatedTokens{}, err
|
return CreatedTokens{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load the profile, which we need for the ID token
|
||||||
|
userClaims, err := s.getUserClaims(ctx, &storedRefreshToken.User, storedRefreshToken.Scopes(), tx)
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new ID token
|
||||||
|
// There's no nonce here because we don't have one with the refresh token, but that's not required
|
||||||
|
idToken, err := s.jwtService.GenerateIDToken(userClaims, input.ClientID, "")
|
||||||
|
if err != nil {
|
||||||
|
return CreatedTokens{}, err
|
||||||
|
}
|
||||||
|
|
||||||
// Generate a new refresh token and invalidate the old one
|
// Generate a new refresh token and invalidate the old one
|
||||||
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
newRefreshToken, err := s.createRefreshToken(ctx, input.ClientID, storedRefreshToken.UserID, storedRefreshToken.Scope, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -488,7 +538,8 @@ func (s *OidcService) createTokenFromRefreshToken(ctx context.Context, input dto
|
|||||||
return CreatedTokens{
|
return CreatedTokens{
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
RefreshToken: newRefreshToken,
|
RefreshToken: newRefreshToken,
|
||||||
ExpiresIn: time.Hour,
|
IdToken: idToken,
|
||||||
|
ExpiresIn: AccessTokenDuration,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -669,6 +720,11 @@ func (s *OidcService) ListClients(ctx context.Context, name string, sortedPagina
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCreateDto, userID string) (model.OidcClient, error) {
|
||||||
|
tx := s.db.Begin()
|
||||||
|
defer func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}()
|
||||||
|
|
||||||
client := model.OidcClient{
|
client := model.OidcClient{
|
||||||
Base: model.Base{
|
Base: model.Base{
|
||||||
ID: input.ID,
|
ID: input.ID,
|
||||||
@@ -677,7 +733,7 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
|||||||
}
|
}
|
||||||
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
updateOIDCClientModelFromDto(&client, &input.OidcClientUpdateDto)
|
||||||
|
|
||||||
err := s.db.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Create(&client).
|
Create(&client).
|
||||||
Error
|
Error
|
||||||
@@ -688,33 +744,11 @@ func (s *OidcService) CreateClient(ctx context.Context, input dto.OidcClientCrea
|
|||||||
return model.OidcClient{}, err
|
return model.OidcClient{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, nil
|
if input.LogoURL != nil {
|
||||||
}
|
err = s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
|
||||||
|
if err != nil {
|
||||||
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||||
tx := s.db.Begin()
|
}
|
||||||
defer func() {
|
|
||||||
tx.Rollback()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
err := tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Preload("CreatedBy").
|
|
||||||
First(&client, "id = ?", clientID).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return model.OidcClient{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
updateOIDCClientModelFromDto(&client, &input)
|
|
||||||
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Save(&client).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return model.OidcClient{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit().Error
|
err = tx.Commit().Error
|
||||||
@@ -725,6 +759,36 @@ func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input d
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) UpdateClient(ctx context.Context, clientID string, input dto.OidcClientUpdateDto) (model.OidcClient, error) {
|
||||||
|
tx := s.db.Begin()
|
||||||
|
defer func() { tx.Rollback() }()
|
||||||
|
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := tx.WithContext(ctx).
|
||||||
|
Preload("CreatedBy").
|
||||||
|
First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return model.OidcClient{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
updateOIDCClientModelFromDto(&client, &input)
|
||||||
|
|
||||||
|
if err := tx.WithContext(ctx).Save(&client).Error; err != nil {
|
||||||
|
return model.OidcClient{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if input.LogoURL != nil {
|
||||||
|
err := s.downloadAndSaveLogoFromURL(ctx, tx, client.ID, *input.LogoURL)
|
||||||
|
if err != nil {
|
||||||
|
return model.OidcClient{}, fmt.Errorf("failed to download logo: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit().Error; err != nil {
|
||||||
|
return model.OidcClient{}, err
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) {
|
func updateOIDCClientModelFromDto(client *model.OidcClient, input *dto.OidcClientUpdateDto) {
|
||||||
// Base fields
|
// Base fields
|
||||||
client.Name = input.Name
|
client.Name = input.Name
|
||||||
@@ -838,41 +902,14 @@ func (s *OidcService) UpdateClientLogo(ctx context.Context, clientID string, fil
|
|||||||
}
|
}
|
||||||
|
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
|
||||||
|
err = s.updateClientLogoType(ctx, tx, clientID, fileType)
|
||||||
|
if err != nil {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
}()
|
|
||||||
|
|
||||||
var client model.OidcClient
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
First(&client, "id = ?", clientID).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.ImageType != nil && fileType != *client.ImageType {
|
return tx.Commit().Error
|
||||||
oldImagePath := fmt.Sprintf("%s/oidc-client-images/%s.%s", common.EnvConfig.UploadPath, client.ID, *client.ImageType)
|
|
||||||
if err := os.Remove(oldImagePath); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
client.ImageType = &fileType
|
|
||||||
err = tx.
|
|
||||||
WithContext(ctx).
|
|
||||||
Save(&client).
|
|
||||||
Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit().Error
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) error {
|
||||||
@@ -896,6 +933,7 @@ func (s *OidcService) DeleteClientLogo(ctx context.Context, clientID string) err
|
|||||||
|
|
||||||
oldImageType := *client.ImageType
|
oldImageType := *client.ImageType
|
||||||
client.ImageType = nil
|
client.ImageType = nil
|
||||||
|
|
||||||
err = tx.
|
err = tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
Save(&client).
|
Save(&client).
|
||||||
@@ -1288,7 +1326,7 @@ func (s *OidcService) GetDeviceCodeInfo(ctx context.Context, userCode string, us
|
|||||||
Client: dto.OidcClientMetaDataDto{
|
Client: dto.OidcClientMetaDataDto{
|
||||||
ID: deviceAuth.Client.ID,
|
ID: deviceAuth.Client.ID,
|
||||||
Name: deviceAuth.Client.Name,
|
Name: deviceAuth.Client.Name,
|
||||||
HasLogo: deviceAuth.Client.HasLogo,
|
HasLogo: deviceAuth.Client.HasLogo(),
|
||||||
},
|
},
|
||||||
Scope: deviceAuth.Scope,
|
Scope: deviceAuth.Scope,
|
||||||
AuthorizationRequired: !hasAuthorizedClient,
|
AuthorizationRequired: !hasAuthorizedClient,
|
||||||
@@ -1383,14 +1421,18 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
|
|||||||
|
|
||||||
// If user has no groups, only return clients with no allowed user groups
|
// If user has no groups, only return clients with no allowed user groups
|
||||||
if len(userGroupIDs) == 0 {
|
if len(userGroupIDs) == 0 {
|
||||||
query = query.
|
query = query.Where(`NOT EXISTS (
|
||||||
Joins("LEFT JOIN oidc_clients_allowed_user_groups ON oidc_clients.id = oidc_clients_allowed_user_groups.oidc_client_id").
|
SELECT 1 FROM oidc_clients_allowed_user_groups
|
||||||
Where("oidc_clients_allowed_user_groups.oidc_client_id IS NULL")
|
WHERE oidc_clients_allowed_user_groups.oidc_client_id = oidc_clients.id)`)
|
||||||
} else {
|
} else {
|
||||||
// Return clients with no allowed user groups OR clients where user is in allowed groups
|
query = query.Where(`
|
||||||
query = query.
|
NOT EXISTS (
|
||||||
Joins("LEFT JOIN oidc_clients_allowed_user_groups ON oidc_clients.id = oidc_clients_allowed_user_groups.oidc_client_id").
|
SELECT 1 FROM oidc_clients_allowed_user_groups
|
||||||
Where("oidc_clients_allowed_user_groups.oidc_client_id IS NULL OR oidc_clients_allowed_user_groups.user_group_id IN (?)", userGroupIDs)
|
WHERE oidc_clients_allowed_user_groups.oidc_client_id = oidc_clients.id
|
||||||
|
) OR EXISTS (
|
||||||
|
SELECT 1 FROM oidc_clients_allowed_user_groups
|
||||||
|
WHERE oidc_clients_allowed_user_groups.oidc_client_id = oidc_clients.id
|
||||||
|
AND oidc_clients_allowed_user_groups.user_group_id IN (?))`, userGroupIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
var clients []model.OidcClient
|
var clients []model.OidcClient
|
||||||
@@ -1419,7 +1461,7 @@ func (s *OidcService) ListAccessibleOidcClients(ctx context.Context, userID stri
|
|||||||
ID: client.ID,
|
ID: client.ID,
|
||||||
Name: client.Name,
|
Name: client.Name,
|
||||||
LaunchURL: client.LaunchURL,
|
LaunchURL: client.LaunchURL,
|
||||||
HasLogo: client.HasLogo,
|
HasLogo: client.HasLogo(),
|
||||||
},
|
},
|
||||||
LastUsedAt: lastUsedAt,
|
LastUsedAt: lastUsedAt,
|
||||||
}
|
}
|
||||||
@@ -1690,7 +1732,7 @@ func (s *OidcService) extractClientIDFromAssertion(assertion string) (string, er
|
|||||||
return sub, nil
|
return sub, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes string) (*dto.OidcClientPreviewDto, error) {
|
func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, userID string, scopes []string) (*dto.OidcClientPreviewDto, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer func() {
|
||||||
tx.Rollback()
|
tx.Rollback()
|
||||||
@@ -1715,14 +1757,7 @@ func (s *OidcService) GetClientPreview(ctx context.Context, clientID string, use
|
|||||||
return nil, &common.OidcAccessDeniedError{}
|
return nil, &common.OidcAccessDeniedError{}
|
||||||
}
|
}
|
||||||
|
|
||||||
dummyAuthorizedClient := model.UserAuthorizedOidcClient{
|
userClaims, err := s.getUserClaims(ctx, &user, scopes, tx)
|
||||||
UserID: userID,
|
|
||||||
ClientID: clientID,
|
|
||||||
Scope: scopes,
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
|
|
||||||
userClaims, err := s.getUserClaimsFromAuthorizedClient(ctx, &dummyAuthorizedClient, tx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1775,14 +1810,10 @@ func (s *OidcService) getUserClaimsForClientInternal(ctx context.Context, userID
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.getUserClaimsFromAuthorizedClient(ctx, &authorizedOidcClient, tx)
|
return s.getUserClaims(ctx, &authorizedOidcClient.User, authorizedOidcClient.Scopes(), tx)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, authorizedClient *model.UserAuthorizedOidcClient, tx *gorm.DB) (map[string]any, error) {
|
func (s *OidcService) getUserClaims(ctx context.Context, user *model.User, scopes []string, tx *gorm.DB) (map[string]any, error) {
|
||||||
user := authorizedClient.User
|
|
||||||
scopes := strings.Split(authorizedClient.Scope, " ")
|
|
||||||
|
|
||||||
claims := make(map[string]any, 10)
|
claims := make(map[string]any, 10)
|
||||||
|
|
||||||
claims["sub"] = user.ID
|
claims["sub"] = user.ID
|
||||||
@@ -1800,13 +1831,6 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "profile") {
|
if slices.Contains(scopes, "profile") {
|
||||||
// Add profile claims
|
|
||||||
claims["given_name"] = user.FirstName
|
|
||||||
claims["family_name"] = user.LastName
|
|
||||||
claims["name"] = user.FullName()
|
|
||||||
claims["preferred_username"] = user.Username
|
|
||||||
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
|
|
||||||
|
|
||||||
// Add custom claims
|
// Add custom claims
|
||||||
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
customClaims, err := s.customClaimService.GetCustomClaimsForUserWithUserGroups(ctx, user.ID, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1825,6 +1849,15 @@ func (s *OidcService) getUserClaimsFromAuthorizedClient(ctx context.Context, aut
|
|||||||
claims[customClaim.Key] = customClaim.Value
|
claims[customClaim.Key] = customClaim.Value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add profile claims
|
||||||
|
claims["given_name"] = user.FirstName
|
||||||
|
claims["family_name"] = user.LastName
|
||||||
|
claims["name"] = user.FullName()
|
||||||
|
claims["display_name"] = user.DisplayName
|
||||||
|
|
||||||
|
claims["preferred_username"] = user.Username
|
||||||
|
claims["picture"] = common.EnvConfig.AppURL + "/api/users/" + user.ID + "/profile-picture.png"
|
||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "email") {
|
if slices.Contains(scopes, "email") {
|
||||||
@@ -1849,3 +1882,93 @@ func (s *OidcService) IsClientAccessibleToUser(ctx context.Context, clientID str
|
|||||||
|
|
||||||
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
return s.IsUserGroupAllowedToAuthorize(user, client), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) downloadAndSaveLogoFromURL(parentCtx context.Context, tx *gorm.DB, clientID string, raw string) error {
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(parentCtx, 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
r := net.Resolver{}
|
||||||
|
ips, err := r.LookupIPAddr(ctx, u.Hostname())
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return fmt.Errorf("cannot resolve hostname")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prevents SSRF by allowing only public IPs
|
||||||
|
for _, addr := range ips {
|
||||||
|
if utils.IsPrivateIP(addr.IP) {
|
||||||
|
return fmt.Errorf("private IP addresses are not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, raw, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", "pocket-id/oidc-logo-fetcher")
|
||||||
|
req.Header.Set("Accept", "image/*")
|
||||||
|
|
||||||
|
resp, err := s.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("failed to fetch logo: %s", resp.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxLogoSize int64 = 2 * 1024 * 1024 // 2MB
|
||||||
|
if resp.ContentLength > maxLogoSize {
|
||||||
|
return fmt.Errorf("logo is too large")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer extension in path if supported
|
||||||
|
ext := utils.GetFileExtension(u.Path)
|
||||||
|
if ext == "" || utils.GetImageMimeType(ext) == "" {
|
||||||
|
// Otherwise, try to detect from content type
|
||||||
|
ext = utils.GetImageExtensionFromMimeType(resp.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if ext == "" {
|
||||||
|
return &common.FileTypeNotSupportedError{}
|
||||||
|
}
|
||||||
|
|
||||||
|
folderPath := filepath.Join(common.EnvConfig.UploadPath, "oidc-client-images")
|
||||||
|
err = os.MkdirAll(folderPath, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
imagePath := filepath.Join(folderPath, clientID+"."+ext)
|
||||||
|
err = utils.SaveFileStream(io.LimitReader(resp.Body, maxLogoSize+1), imagePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.updateClientLogoType(ctx, tx, clientID, ext); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *OidcService) updateClientLogoType(ctx context.Context, tx *gorm.DB, clientID, ext string) error {
|
||||||
|
uploadsDir := common.EnvConfig.UploadPath + "/oidc-client-images"
|
||||||
|
|
||||||
|
var client model.OidcClient
|
||||||
|
if err := tx.WithContext(ctx).First(&client, "id = ?", clientID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if client.ImageType != nil && *client.ImageType != ext {
|
||||||
|
old := fmt.Sprintf("%s/%s.%s", uploadsDir, client.ID, *client.ImageType)
|
||||||
|
_ = os.Remove(old)
|
||||||
|
}
|
||||||
|
client.ImageType = &ext
|
||||||
|
return tx.WithContext(ctx).Save(&client).Error
|
||||||
|
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/common"
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
"github.com/pocket-id/pocket-id/backend/internal/dto"
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
testutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -148,6 +149,13 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
privateJWKDefaults, jwkSetJSONDefaults := generateTestECDSAKey(t)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a mock config and JwtService to test complete a token creation process
|
||||||
|
mockConfig := NewTestAppConfigService(&model.AppConfig{
|
||||||
|
SessionDuration: model.AppConfigVariable{Value: "60"}, // 60 minutes
|
||||||
|
})
|
||||||
|
mockJwtService, err := NewJwtService(db, mockConfig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create a mock HTTP client with custom transport to return the JWKS
|
// Create a mock HTTP client with custom transport to return the JWKS
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Transport: &testutils.MockRoundTripper{
|
Transport: &testutils.MockRoundTripper{
|
||||||
@@ -162,8 +170,10 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
|
|
||||||
// Init the OidcService
|
// Init the OidcService
|
||||||
s := &OidcService{
|
s := &OidcService{
|
||||||
db: db,
|
db: db,
|
||||||
httpClient: httpClient,
|
jwtService: mockJwtService,
|
||||||
|
appConfigService: mockConfig,
|
||||||
|
httpClient: httpClient,
|
||||||
}
|
}
|
||||||
s.jwkCache, err = s.getJWKCache(t.Context())
|
s.jwkCache, err = s.getJWKCache(t.Context())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -384,4 +394,119 @@ func TestOidcService_verifyClientCredentialsInternal(t *testing.T) {
|
|||||||
assert.Equal(t, federatedClient.ID, client.ID)
|
assert.Equal(t, federatedClient.ID, client.ID)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("Complete token creation flow", func(t *testing.T) {
|
||||||
|
t.Run("Client Credentials flow", func(t *testing.T) {
|
||||||
|
t.Run("Succeeds with valid secret", func(t *testing.T) {
|
||||||
|
// Generate a token
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientID: confidentialClient.ID,
|
||||||
|
ClientSecret: confidentialSecret,
|
||||||
|
}
|
||||||
|
token, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, token)
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := s.jwtService.VerifyOAuthAccessToken(token.AccessToken)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "client-"+confidentialClient.ID, subject, "Token subject should match confidential client ID with prefix")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.Equal(t, []string{confidentialClient.ID}, audience, "Audience should contain confidential client ID")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Fails with invalid secret", func(t *testing.T) {
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientID: confidentialClient.ID,
|
||||||
|
ClientSecret: "invalid-secret",
|
||||||
|
}
|
||||||
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcClientSecretInvalidError{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Fails without client secret for public clients", func(t *testing.T) {
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientID: publicClient.ID,
|
||||||
|
}
|
||||||
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcMissingClientCredentialsError{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Succeeds with valid assertion", func(t *testing.T) {
|
||||||
|
// Create JWT for federated identity
|
||||||
|
token, err := jwt.NewBuilder().
|
||||||
|
Issuer(federatedClientIssuer).
|
||||||
|
Audience([]string{federatedClientAudience}).
|
||||||
|
Subject(federatedClient.ID).
|
||||||
|
IssuedAt(time.Now()).
|
||||||
|
Expiration(time.Now().Add(10 * time.Minute)).
|
||||||
|
Build()
|
||||||
|
require.NoError(t, err)
|
||||||
|
signedToken, err := jwt.Sign(token, jwt.WithKey(jwa.ES256(), privateJWK))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate a token
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientAssertion: string(signedToken),
|
||||||
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
|
}
|
||||||
|
createdToken, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, token)
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := s.jwtService.VerifyOAuthAccessToken(createdToken.AccessToken)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "client-"+federatedClient.ID, subject, "Token subject should match federated client ID with prefix")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.Equal(t, []string{federatedClient.ID}, audience, "Audience should contain the federated client ID")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Fails with invalid assertion", func(t *testing.T) {
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientAssertion: "invalid.jwt.token",
|
||||||
|
ClientAssertionType: ClientAssertionTypeJWTBearer,
|
||||||
|
}
|
||||||
|
_, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, &common.OidcClientAssertionInvalidError{})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Succeeds with custom resource", func(t *testing.T) {
|
||||||
|
// Generate a token
|
||||||
|
input := dto.OidcCreateTokensDto{
|
||||||
|
ClientID: confidentialClient.ID,
|
||||||
|
ClientSecret: confidentialSecret,
|
||||||
|
Resource: "https://example.com/",
|
||||||
|
}
|
||||||
|
token, err := s.createTokenFromClientCredentials(t.Context(), input)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, token)
|
||||||
|
|
||||||
|
// Verify the token
|
||||||
|
claims, err := s.jwtService.VerifyOAuthAccessToken(token.AccessToken)
|
||||||
|
require.NoError(t, err, "Failed to verify generated token")
|
||||||
|
|
||||||
|
// Check the claims
|
||||||
|
subject, ok := claims.Subject()
|
||||||
|
_ = assert.True(t, ok, "User ID not found in token") &&
|
||||||
|
assert.Equal(t, "client-"+confidentialClient.ID, subject, "Token subject should match confidential client ID with prefix")
|
||||||
|
audience, ok := claims.Audience()
|
||||||
|
_ = assert.True(t, ok, "Audience not found in token") &&
|
||||||
|
assert.Equal(t, []string{input.Resource}, audience, "Audience should contain the resource provided in request")
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -244,13 +244,18 @@ func (s *UserService) CreateUser(ctx context.Context, input dto.UserCreateDto) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
func (s *UserService) createUserInternal(ctx context.Context, input dto.UserCreateDto, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
||||||
|
if s.appConfigService.GetDbConfig().RequireUserEmail.IsTrue() && input.Email == nil {
|
||||||
|
return model.User{}, &common.UserEmailNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
user := model.User{
|
user := model.User{
|
||||||
FirstName: input.FirstName,
|
FirstName: input.FirstName,
|
||||||
LastName: input.LastName,
|
LastName: input.LastName,
|
||||||
Email: input.Email,
|
DisplayName: input.DisplayName,
|
||||||
Username: input.Username,
|
Email: input.Email,
|
||||||
IsAdmin: input.IsAdmin,
|
Username: input.Username,
|
||||||
Locale: input.Locale,
|
IsAdmin: input.IsAdmin,
|
||||||
|
Locale: input.Locale,
|
||||||
}
|
}
|
||||||
if input.LdapID != "" {
|
if input.LdapID != "" {
|
||||||
user.LdapID = &input.LdapID
|
user.LdapID = &input.LdapID
|
||||||
@@ -338,6 +343,10 @@ func (s *UserService) UpdateUser(ctx context.Context, userID string, updatedUser
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
func (s *UserService) updateUserInternal(ctx context.Context, userID string, updatedUser dto.UserCreateDto, updateOwnUser bool, isLdapSync bool, tx *gorm.DB) (model.User, error) {
|
||||||
|
if s.appConfigService.GetDbConfig().RequireUserEmail.IsTrue() && updatedUser.Email == nil {
|
||||||
|
return model.User{}, &common.UserEmailNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
var user model.User
|
var user model.User
|
||||||
err := tx.
|
err := tx.
|
||||||
WithContext(ctx).
|
WithContext(ctx).
|
||||||
@@ -362,6 +371,7 @@ func (s *UserService) updateUserInternal(ctx context.Context, userID string, upd
|
|||||||
// Full update: Allow updating all personal fields
|
// Full update: Allow updating all personal fields
|
||||||
user.FirstName = updatedUser.FirstName
|
user.FirstName = updatedUser.FirstName
|
||||||
user.LastName = updatedUser.LastName
|
user.LastName = updatedUser.LastName
|
||||||
|
user.DisplayName = updatedUser.DisplayName
|
||||||
user.Email = updatedUser.Email
|
user.Email = updatedUser.Email
|
||||||
user.Username = updatedUser.Username
|
user.Username = updatedUser.Username
|
||||||
user.Locale = updatedUser.Locale
|
user.Locale = updatedUser.Locale
|
||||||
@@ -435,6 +445,10 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if user.Email == nil {
|
||||||
|
return &common.UserEmailNotSetError{}
|
||||||
|
}
|
||||||
|
|
||||||
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx)
|
oneTimeAccessToken, err := s.createOneTimeAccessTokenInternal(ctx, user.ID, ttl, tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -462,7 +476,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
|
|||||||
|
|
||||||
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
|
errInternal := SendEmail(innerCtx, s.emailService, email.Address{
|
||||||
Name: user.FullName(),
|
Name: user.FullName(),
|
||||||
Email: user.Email,
|
Email: *user.Email,
|
||||||
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
|
}, OneTimeAccessTemplate, &OneTimeAccessTemplateData{
|
||||||
Code: oneTimeAccessToken,
|
Code: oneTimeAccessToken,
|
||||||
LoginLink: link,
|
LoginLink: link,
|
||||||
@@ -470,7 +484,7 @@ func (s *UserService) requestOneTimeAccessEmailInternal(ctx context.Context, use
|
|||||||
ExpirationString: utils.DurationToString(ttl),
|
ExpirationString: utils.DurationToString(ttl),
|
||||||
})
|
})
|
||||||
if errInternal != nil {
|
if errInternal != nil {
|
||||||
slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", user.Email))
|
slog.ErrorContext(innerCtx, "Failed to send one-time access token email", slog.Any("error", errInternal), slog.String("address", *user.Email))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -600,11 +614,12 @@ func (s *UserService) SignUpInitialAdmin(ctx context.Context, signUpData dto.Sig
|
|||||||
}
|
}
|
||||||
|
|
||||||
userToCreate := dto.UserCreateDto{
|
userToCreate := dto.UserCreateDto{
|
||||||
FirstName: signUpData.FirstName,
|
FirstName: signUpData.FirstName,
|
||||||
LastName: signUpData.LastName,
|
LastName: signUpData.LastName,
|
||||||
Username: signUpData.Username,
|
DisplayName: strings.TrimSpace(signUpData.FirstName + " " + signUpData.LastName),
|
||||||
Email: signUpData.Email,
|
Username: signUpData.Username,
|
||||||
IsAdmin: true,
|
Email: signUpData.Email,
|
||||||
|
IsAdmin: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
||||||
@@ -736,10 +751,11 @@ func (s *UserService) SignUp(ctx context.Context, signupData dto.SignUpDto, ipAd
|
|||||||
}
|
}
|
||||||
|
|
||||||
userToCreate := dto.UserCreateDto{
|
userToCreate := dto.UserCreateDto{
|
||||||
Username: signupData.Username,
|
Username: signupData.Username,
|
||||||
Email: signupData.Email,
|
Email: signupData.Email,
|
||||||
FirstName: signupData.FirstName,
|
FirstName: signupData.FirstName,
|
||||||
LastName: signupData.LastName,
|
LastName: signupData.LastName,
|
||||||
|
DisplayName: strings.TrimSpace(signupData.FirstName + " " + signupData.LastName),
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
user, err := s.createUserInternal(ctx, userToCreate, false, tx)
|
||||||
|
|||||||
74
backend/internal/service/version_service.go
Normal file
74
backend/internal/service/version_service.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
versionTTL = 15 * time.Minute
|
||||||
|
versionCheckURL = "https://api.github.com/repos/pocket-id/pocket-id/releases/latest"
|
||||||
|
)
|
||||||
|
|
||||||
|
type VersionService struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
cache *utils.Cache[string]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVersionService(httpClient *http.Client) *VersionService {
|
||||||
|
return &VersionService{
|
||||||
|
httpClient: httpClient,
|
||||||
|
cache: utils.New[string](versionTTL),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *VersionService) GetLatestVersion(ctx context.Context) (string, error) {
|
||||||
|
version, err := s.cache.GetOrFetch(ctx, func(ctx context.Context) (string, error) {
|
||||||
|
reqCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, versionCheckURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("create GitHub request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get latest tag: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
TagName string `json:"tag_name"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
|
||||||
|
return "", fmt.Errorf("decode payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if payload.TagName == "" {
|
||||||
|
return "", fmt.Errorf("GitHub API returned empty tag name")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimPrefix(payload.TagName, "v"), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var staleErr *utils.ErrStale
|
||||||
|
if errors.As(err, &staleErr) {
|
||||||
|
slog.Warn("Failed to fetch latest version, returning stale cache", "error", staleErr.Err)
|
||||||
|
return version, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return version, err
|
||||||
|
}
|
||||||
78
backend/internal/utils/cache_util.go
Normal file
78
backend/internal/utils/cache_util.go
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CacheEntry[T any] struct {
|
||||||
|
Value T
|
||||||
|
FetchedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrStale struct {
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ErrStale) Error() string { return "returned stale cache: " + e.Err.Error() }
|
||||||
|
func (e *ErrStale) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
type Cache[T any] struct {
|
||||||
|
ttl time.Duration
|
||||||
|
entry atomic.Pointer[CacheEntry[T]]
|
||||||
|
sf singleflight.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func New[T any](ttl time.Duration) *Cache[T] {
|
||||||
|
return &Cache[T]{ttl: ttl}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the cached value if it's still fresh.
|
||||||
|
func (c *Cache[T]) Get() (T, bool) {
|
||||||
|
entry := c.entry.Load()
|
||||||
|
if entry == nil {
|
||||||
|
var zero T
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
if time.Since(entry.FetchedAt) < c.ttl {
|
||||||
|
return entry.Value, true
|
||||||
|
}
|
||||||
|
var zero T
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrFetch returns the cached value if it's still fresh, otherwise calls fetch to get a new value.
|
||||||
|
func (c *Cache[T]) GetOrFetch(ctx context.Context, fetch func(context.Context) (T, error)) (T, error) {
|
||||||
|
// If fresh, serve immediately
|
||||||
|
if v, ok := c.Get(); ok {
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch with singleflight to prevent multiple concurrent fetches
|
||||||
|
vAny, err, _ := c.sf.Do("singleton", func() (any, error) {
|
||||||
|
if v2, ok := c.Get(); ok {
|
||||||
|
return v2, nil
|
||||||
|
}
|
||||||
|
val, fetchErr := fetch(ctx)
|
||||||
|
if fetchErr != nil {
|
||||||
|
return nil, fetchErr
|
||||||
|
}
|
||||||
|
c.entry.Store(&CacheEntry[T]{Value: val, FetchedAt: time.Now()})
|
||||||
|
return val, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return vAny.(T), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch failed. Return stale if possible.
|
||||||
|
if e := c.entry.Load(); e != nil {
|
||||||
|
return e.Value, &ErrStale{Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
@@ -3,7 +3,6 @@ package email
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
htemplate "html/template"
|
htemplate "html/template"
|
||||||
"io/fs"
|
|
||||||
"path"
|
"path"
|
||||||
ttemplate "text/template"
|
ttemplate "text/template"
|
||||||
|
|
||||||
@@ -27,71 +26,35 @@ func GetTemplate[U any, V any](templateMap TemplateMap[U], template Template[V])
|
|||||||
return templateMap[template.Path]
|
return templateMap[template.Path]
|
||||||
}
|
}
|
||||||
|
|
||||||
type cloneable[V pareseable[V]] interface {
|
|
||||||
Clone() (V, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type pareseable[V any] interface {
|
|
||||||
ParseFS(fs.FS, ...string) (V, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareTemplate[V pareseable[V]](templateFS fs.FS, template string, rootTemplate cloneable[V], suffix string) (V, error) {
|
|
||||||
tmpl, err := rootTemplate.Clone()
|
|
||||||
if err != nil {
|
|
||||||
return *new(V), fmt.Errorf("clone root template: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
filename := fmt.Sprintf("%s%s", template, suffix)
|
|
||||||
templatePath := path.Join("email-templates", filename)
|
|
||||||
_, err = tmpl.ParseFS(templateFS, templatePath)
|
|
||||||
if err != nil {
|
|
||||||
return *new(V), fmt.Errorf("parsing template '%s': %w", template, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tmpl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func PrepareTextTemplates(templates []string) (map[string]*ttemplate.Template, error) {
|
func PrepareTextTemplates(templates []string) (map[string]*ttemplate.Template, error) {
|
||||||
components := path.Join("email-templates", "components", "*_text.tmpl")
|
|
||||||
rootTmpl, err := ttemplate.ParseFS(resources.FS, components)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse templates '%s': %w", components, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
textTemplates := make(map[string]*ttemplate.Template, len(templates))
|
textTemplates := make(map[string]*ttemplate.Template, len(templates))
|
||||||
for _, tmpl := range templates {
|
for _, tmpl := range templates {
|
||||||
rootTmplClone, err := rootTmpl.Clone()
|
filename := tmpl + "_text.tmpl"
|
||||||
|
templatePath := path.Join("email-templates", filename)
|
||||||
|
|
||||||
|
parsedTemplate, err := ttemplate.ParseFS(resources.FS, templatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("clone root template: %w", err)
|
return nil, fmt.Errorf("parsing template '%s': %w", tmpl, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
textTemplates[tmpl], err = prepareTemplate[*ttemplate.Template](resources.FS, tmpl, rootTmplClone, "_text.tmpl")
|
textTemplates[tmpl] = parsedTemplate
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse '%s': %w", tmpl, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return textTemplates, nil
|
return textTemplates, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func PrepareHTMLTemplates(templates []string) (map[string]*htemplate.Template, error) {
|
func PrepareHTMLTemplates(templates []string) (map[string]*htemplate.Template, error) {
|
||||||
components := path.Join("email-templates", "components", "*_html.tmpl")
|
|
||||||
rootTmpl, err := htemplate.ParseFS(resources.FS, components)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to parse templates '%s': %w", components, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
htmlTemplates := make(map[string]*htemplate.Template, len(templates))
|
htmlTemplates := make(map[string]*htemplate.Template, len(templates))
|
||||||
for _, tmpl := range templates {
|
for _, tmpl := range templates {
|
||||||
rootTmplClone, err := rootTmpl.Clone()
|
filename := tmpl + "_html.tmpl"
|
||||||
|
templatePath := path.Join("email-templates", filename)
|
||||||
|
|
||||||
|
parsedTemplate, err := htemplate.ParseFS(resources.FS, templatePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("clone root template: %w", err)
|
return nil, fmt.Errorf("parsing template '%s': %w", tmpl, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
htmlTemplates[tmpl], err = prepareTemplate[*htemplate.Template](resources.FS, tmpl, rootTmplClone, "_html.tmpl")
|
htmlTemplates[tmpl] = parsedTemplate
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse '%s': %w", tmpl, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return htmlTemplates, nil
|
return htmlTemplates, nil
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"mime"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -23,6 +26,15 @@ func GetFileExtension(filename string) string {
|
|||||||
return filename
|
return filename
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SplitFileName splits a full file name into name and extension.
|
||||||
|
func SplitFileName(fullName string) (name, ext string) {
|
||||||
|
dot := strings.LastIndex(fullName, ".")
|
||||||
|
if dot == -1 || dot == 0 {
|
||||||
|
return fullName, "" // no extension or hidden file like .gitignore
|
||||||
|
}
|
||||||
|
return fullName[:dot], fullName[dot+1:]
|
||||||
|
}
|
||||||
|
|
||||||
func GetImageMimeType(ext string) string {
|
func GetImageMimeType(ext string) string {
|
||||||
switch ext {
|
switch ext {
|
||||||
case "jpg", "jpeg":
|
case "jpg", "jpeg":
|
||||||
@@ -35,6 +47,40 @@ func GetImageMimeType(ext string) string {
|
|||||||
return "image/x-icon"
|
return "image/x-icon"
|
||||||
case "gif":
|
case "gif":
|
||||||
return "image/gif"
|
return "image/gif"
|
||||||
|
case "webp":
|
||||||
|
return "image/webp"
|
||||||
|
case "avif":
|
||||||
|
return "image/avif"
|
||||||
|
case "heic":
|
||||||
|
return "image/heic"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetImageExtensionFromMimeType(mimeType string) string {
|
||||||
|
// Normalize and strip parameters like `; charset=utf-8`
|
||||||
|
mt := strings.TrimSpace(strings.ToLower(mimeType))
|
||||||
|
if v, _, err := mime.ParseMediaType(mt); err == nil {
|
||||||
|
mt = v
|
||||||
|
}
|
||||||
|
switch mt {
|
||||||
|
case "image/jpeg", "image/jpg":
|
||||||
|
return "jpg"
|
||||||
|
case "image/png":
|
||||||
|
return "png"
|
||||||
|
case "image/svg+xml":
|
||||||
|
return "svg"
|
||||||
|
case "image/x-icon", "image/vnd.microsoft.icon":
|
||||||
|
return "ico"
|
||||||
|
case "image/gif":
|
||||||
|
return "gif"
|
||||||
|
case "image/webp":
|
||||||
|
return "webp"
|
||||||
|
case "image/avif":
|
||||||
|
return "avif"
|
||||||
|
case "image/heic", "image/heif":
|
||||||
|
return "heic"
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -43,29 +89,45 @@ func GetImageMimeType(ext string) string {
|
|||||||
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
|
func CopyEmbeddedFileToDisk(srcFilePath, destFilePath string) error {
|
||||||
srcFile, err := resources.FS.Open(srcFilePath)
|
srcFile, err := resources.FS.Open(srcFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to open embedded file: %w", err)
|
||||||
}
|
}
|
||||||
defer srcFile.Close()
|
defer srcFile.Close()
|
||||||
|
|
||||||
err = os.MkdirAll(filepath.Dir(destFilePath), os.ModePerm)
|
err = os.MkdirAll(filepath.Dir(destFilePath), os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to create destination directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
destFile, err := os.Create(destFilePath)
|
destFile, err := os.Create(destFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to open destination file: %w", err)
|
||||||
}
|
}
|
||||||
defer destFile.Close()
|
defer destFile.Close()
|
||||||
|
|
||||||
_, err = io.Copy(destFile, srcFile)
|
_, err = io.Copy(destFile, srcFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("failed to write to destination file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EmbeddedFileSha256(filePath string) ([]byte, error) {
|
||||||
|
f, err := resources.FS.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open embedded file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
_, err = io.Copy(h, f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read embedded file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|
||||||
func SaveFile(file *multipart.FileHeader, dst string) error {
|
func SaveFile(file *multipart.FileHeader, dst string) error {
|
||||||
src, err := file.Open()
|
src, err := file.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,8 +2,36 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestSplitFileName(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
fullName string
|
||||||
|
wantName string
|
||||||
|
wantExt string
|
||||||
|
}{
|
||||||
|
{"background.jpg", "background", "jpg"},
|
||||||
|
{"archive.tar.gz", "archive.tar", "gz"},
|
||||||
|
{".gitignore", ".gitignore", ""},
|
||||||
|
{"noext", "noext", ""},
|
||||||
|
{"a.b.c", "a.b", "c"},
|
||||||
|
{".hidden.ext", ".hidden", "ext"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.fullName, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
name, ext := SplitFileName(tc.fullName)
|
||||||
|
assert.Equal(t, tc.wantName, name)
|
||||||
|
assert.Equal(t, tc.wantExt, ext)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetFileExtension(t *testing.T) {
|
func TestGetFileExtension(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -3,9 +3,28 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
func CreateSha256Hash(input string) string {
|
func CreateSha256Hash(input string) string {
|
||||||
hash := sha256.Sum256([]byte(input))
|
hash := sha256.Sum256([]byte(input))
|
||||||
return hex.EncodeToString(hash[:])
|
return hex.EncodeToString(hash[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateSha256FileHash(filePath string) ([]byte, error) {
|
||||||
|
f, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
h := sha256.New()
|
||||||
|
_, err = io.Copy(h, f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h.Sum(nil), nil
|
||||||
|
}
|
||||||
|
|||||||
87
backend/internal/utils/ip_util.go
Normal file
87
backend/internal/utils/ip_util.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
var localIPv6Ranges []*net.IPNet
|
||||||
|
|
||||||
|
var localhostIPNets = []*net.IPNet{
|
||||||
|
{IP: net.IPv4(127, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 127.0.0.0/8
|
||||||
|
{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, // ::1/128
|
||||||
|
}
|
||||||
|
|
||||||
|
var privateLanIPNets = []*net.IPNet{
|
||||||
|
{IP: net.IPv4(10, 0, 0, 0), Mask: net.CIDRMask(8, 32)}, // 10.0.0.0/8
|
||||||
|
{IP: net.IPv4(172, 16, 0, 0), Mask: net.CIDRMask(12, 32)}, // 172.16.0.0/12
|
||||||
|
{IP: net.IPv4(192, 168, 0, 0), Mask: net.CIDRMask(16, 32)}, // 192.168.0.0/16
|
||||||
|
}
|
||||||
|
|
||||||
|
var tailscaleIPNets = []*net.IPNet{
|
||||||
|
{IP: net.IPv4(100, 64, 0, 0), Mask: net.CIDRMask(10, 32)}, // 100.64.0.0/10
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsLocalIPv6(ip net.IP) bool {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return listContainsIP(localIPv6Ranges, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsLocalhostIP(ip net.IP) bool {
|
||||||
|
return listContainsIP(localhostIPNets, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPrivateLanIP(ip net.IP) bool {
|
||||||
|
if ip.To4() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return listContainsIP(privateLanIPNets, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsTailscaleIP(ip net.IP) bool {
|
||||||
|
if ip.To4() == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return listContainsIP(tailscaleIPNets, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsPrivateIP(ip net.IP) bool {
|
||||||
|
return IsLocalhostIP(ip) || IsPrivateLanIP(ip) || IsTailscaleIP(ip) || IsLocalIPv6(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func listContainsIP(ipNets []*net.IPNet, ip net.IP) bool {
|
||||||
|
for _, ipNet := range ipNets {
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadLocalIPv6Ranges() {
|
||||||
|
localIPv6Ranges = nil
|
||||||
|
ranges := strings.Split(common.EnvConfig.LocalIPv6Ranges, ",")
|
||||||
|
|
||||||
|
for _, rangeStr := range ranges {
|
||||||
|
rangeStr = strings.TrimSpace(rangeStr)
|
||||||
|
if rangeStr == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ipNet, err := net.ParseCIDR(rangeStr)
|
||||||
|
if err == nil {
|
||||||
|
localIPv6Ranges = append(localIPv6Ranges, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
loadLocalIPv6Ranges()
|
||||||
|
}
|
||||||
159
backend/internal/utils/ip_util_test.go
Normal file
159
backend/internal/utils/ip_util_test.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pocket-id/pocket-id/backend/internal/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsLocalhostIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", true},
|
||||||
|
{"127.255.255.255", true},
|
||||||
|
{"::1", true},
|
||||||
|
{"192.168.1.1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := IsLocalhostIP(ip); got != tt.expected {
|
||||||
|
t.Errorf("IsLocalhostIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPrivateLanIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"10.0.0.1", true},
|
||||||
|
{"172.16.5.4", true},
|
||||||
|
{"192.168.100.200", true},
|
||||||
|
{"8.8.8.8", false},
|
||||||
|
{"::1", false}, // IPv6 should return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := IsPrivateLanIP(ip); got != tt.expected {
|
||||||
|
t.Errorf("IsPrivateLanIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTailscaleIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"100.64.0.1", true},
|
||||||
|
{"100.127.255.254", true},
|
||||||
|
{"8.8.8.8", false},
|
||||||
|
{"::1", false}, // IPv6 should return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := IsTailscaleIP(ip); got != tt.expected {
|
||||||
|
t.Errorf("IsTailscaleIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsLocalIPv6(t *testing.T) {
|
||||||
|
// Save and restore env config
|
||||||
|
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||||
|
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||||
|
|
||||||
|
common.EnvConfig.LocalIPv6Ranges = "fd00::/8,fc00::/7"
|
||||||
|
localIPv6Ranges = nil // reset
|
||||||
|
loadLocalIPv6Ranges()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"fd00::1", true},
|
||||||
|
{"fc00::abcd", true},
|
||||||
|
{"::1", false}, // loopback handled separately
|
||||||
|
{"192.168.1.1", false}, // IPv4 should return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := IsLocalIPv6(ip); got != tt.expected {
|
||||||
|
t.Errorf("IsLocalIPv6(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsPrivateIP(t *testing.T) {
|
||||||
|
// Save and restore env config
|
||||||
|
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||||
|
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||||
|
|
||||||
|
common.EnvConfig.LocalIPv6Ranges = "fd00::/8"
|
||||||
|
localIPv6Ranges = nil // reset
|
||||||
|
loadLocalIPv6Ranges()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"127.0.0.1", true}, // localhost
|
||||||
|
{"192.168.1.1", true}, // private LAN
|
||||||
|
{"100.64.0.1", true}, // Tailscale
|
||||||
|
{"fd00::1", true}, // local IPv6
|
||||||
|
{"8.8.8.8", false}, // public IPv4
|
||||||
|
{"2001:4860:4860::8888", false}, // public IPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := IsPrivateIP(ip); got != tt.expected {
|
||||||
|
t.Errorf("IsPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListContainsIP(t *testing.T) {
|
||||||
|
_, ipNet1, _ := net.ParseCIDR("10.0.0.0/8")
|
||||||
|
_, ipNet2, _ := net.ParseCIDR("192.168.0.0/16")
|
||||||
|
|
||||||
|
list := []*net.IPNet{ipNet1, ipNet2}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"10.1.1.1", true},
|
||||||
|
{"192.168.5.5", true},
|
||||||
|
{"172.16.0.1", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
ip := net.ParseIP(tt.ip)
|
||||||
|
if got := listContainsIP(list, ip); got != tt.expected {
|
||||||
|
t.Errorf("listContainsIP(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInit_LocalIPv6Ranges(t *testing.T) {
|
||||||
|
// Save and restore env config
|
||||||
|
origRanges := common.EnvConfig.LocalIPv6Ranges
|
||||||
|
defer func() { common.EnvConfig.LocalIPv6Ranges = origRanges }()
|
||||||
|
|
||||||
|
common.EnvConfig.LocalIPv6Ranges = "fd00::/8, invalidCIDR ,fc00::/7"
|
||||||
|
localIPv6Ranges = nil
|
||||||
|
loadLocalIPv6Ranges()
|
||||||
|
|
||||||
|
if len(localIPv6Ranges) != 2 {
|
||||||
|
t.Errorf("expected 2 valid IPv6 ranges, got %d", len(localIPv6Ranges))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/model"
|
"github.com/pocket-id/pocket-id/backend/internal/model"
|
||||||
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
cryptoutils "github.com/pocket-id/pocket-id/backend/internal/utils/crypto"
|
||||||
@@ -95,7 +96,14 @@ func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
|||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err = f.db.WithContext(ctx).Create(&row).Error
|
err = f.db.
|
||||||
|
WithContext(ctx).
|
||||||
|
Clauses(clause.OnConflict{
|
||||||
|
Columns: []clause.Column{{Name: "key"}},
|
||||||
|
DoUpdates: clause.AssignmentColumns([]string{"value"}),
|
||||||
|
}).
|
||||||
|
Create(&row).
|
||||||
|
Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
// There's one scenario where if Pocket ID is started fresh with more than 1 replica, they both could be trying to create the private key in the database at the same time
|
||||||
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
// In this case, only one of the replicas will succeed; the other one(s) will return an error here, which will cascade down and cause the replica(s) to crash and be restarted (at that point they'll load the then-existing key from the database)
|
||||||
|
|||||||
@@ -1,5 +1,16 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
|
// Ptr returns a pointer to the given value.
|
||||||
func Ptr[T any](v T) *T {
|
func Ptr[T any](v T) *T {
|
||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PtrOrNil returns a pointer to v if v is not the zero value of its type,
|
||||||
|
// otherwise it returns nil.
|
||||||
|
func PtrOrNil[T comparable](v T) *T {
|
||||||
|
var zero T
|
||||||
|
if v == zero {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &v
|
||||||
|
}
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
|||||||
|
|
||||||
// Connect to a new in-memory SQL database
|
// Connect to a new in-memory SQL database
|
||||||
db, err := gorm.Open(
|
db, err := gorm.Open(
|
||||||
sqlite.Open("file:"+dbName+"?mode=memory&cache=shared"),
|
sqlite.Open("file:"+dbName+"?mode=memory"),
|
||||||
&gorm.Config{
|
&gorm.Config{
|
||||||
TranslateError: true,
|
TranslateError: true,
|
||||||
Logger: logger.New(
|
Logger: logger.New(
|
||||||
@@ -52,9 +52,14 @@ func NewDatabaseForTest(t *testing.T) *gorm.DB {
|
|||||||
})
|
})
|
||||||
require.NoError(t, err, "Failed to connect to test database")
|
require.NoError(t, err, "Failed to connect to test database")
|
||||||
|
|
||||||
// Perform migrations with the embedded migrations
|
|
||||||
sqlDB, err := db.DB()
|
sqlDB, err := db.DB()
|
||||||
require.NoError(t, err, "Failed to get sql.DB")
|
require.NoError(t, err, "Failed to get sql.DB")
|
||||||
|
|
||||||
|
// For in-memory SQLite databases, we must limit to 1 open connection at the same time, or they won't see the whole data
|
||||||
|
// The other workaround, of using shared caches, doesn't work well with multiple write transactions trying to happen at once
|
||||||
|
sqlDB.SetMaxOpenConns(1)
|
||||||
|
|
||||||
|
// Perform migrations with the embedded migrations
|
||||||
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{
|
driver, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{
|
||||||
NoTxWrap: true,
|
NoTxWrap: true,
|
||||||
})
|
})
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1,17 +1,3 @@
|
|||||||
{{ define "base" }}
|
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><!--$--><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px">
|
||||||
<div class="header">
|
<img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">API Key Expiring Soon</h1></td><td align="right" data-id="__react-email-column">
|
||||||
<div class="logo">
|
<p style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Hello <!-- -->{{.Data.Name}}<!-- -->, <br/>This is a reminder that your API key <strong>{{.Data.APIKeyName}}</strong> <!-- -->will expire on <strong>{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}</strong>.</p><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Please generate a new API key if you need continued access.</p></div></td></tr></tbody></table><!--7--><!--/$--></body></html>{{end}}
|
||||||
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
|
|
||||||
<h1>{{ .AppName }}</h1>
|
|
||||||
</div>
|
|
||||||
<div class="warning">Warning</div>
|
|
||||||
</div>
|
|
||||||
<div class="content">
|
|
||||||
<h2>API Key Expiring Soon</h2>
|
|
||||||
<p>
|
|
||||||
Hello {{ .Data.Name }},<br/><br/>
|
|
||||||
This is a reminder that your API key <strong>{{ .Data.ApiKeyName }}</strong> will expire on <strong>{{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}</strong>.<br/><br/>
|
|
||||||
Please generate a new API key if you need continued access.
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
{{ end }}
|
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
{{ define "base" -}}
|
{{define "root"}}{{.AppName}}
|
||||||
API Key Expiring Soon
|
|
||||||
====================
|
|
||||||
|
|
||||||
Hello {{ .Data.Name }},
|
|
||||||
|
|
||||||
This is a reminder that your API key "{{ .Data.ApiKeyName }}" will expire on {{ .Data.ExpiresAt.Format "2006-01-02 15:04:05 MST" }}.
|
API KEY EXPIRING SOON
|
||||||
|
|
||||||
Please generate a new API key if you need continued access.
|
Warning
|
||||||
{{ end -}}
|
|
||||||
|
Hello {{.Data.Name}},
|
||||||
|
This is a reminder that your API key {{.Data.APIKeyName}} will expire on
|
||||||
|
{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}.
|
||||||
|
|
||||||
|
Please generate a new API key if you need continued access.{{end}}
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
{{ define "root" }}
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
||||||
{{ template "style" . }}
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="container">
|
|
||||||
{{ template "base" . }}
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
{{ end }}
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
{{- define "root" -}}
|
|
||||||
{{- template "base" . -}}
|
|
||||||
{{- end }}
|
|
||||||
|
|
||||||
|
|
||||||
--
|
|
||||||
This is automatically sent email from {{.AppName}}.
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
{{ define "style" }}
|
|
||||||
<style>
|
|
||||||
/* Reset styles for email clients */
|
|
||||||
body, table, td, p, a {
|
|
||||||
margin: 0;
|
|
||||||
padding: 0;
|
|
||||||
border: 0;
|
|
||||||
font-size: 100%;
|
|
||||||
font-family: Arial, sans-serif;
|
|
||||||
line-height: 1.5;
|
|
||||||
}
|
|
||||||
body {
|
|
||||||
background-color: #f0f0f0;
|
|
||||||
color: #333;
|
|
||||||
}
|
|
||||||
.container {
|
|
||||||
width: 100%;
|
|
||||||
max-width: 600px;
|
|
||||||
margin: 40px auto;
|
|
||||||
background-color: #fff;
|
|
||||||
border-radius: 10px;
|
|
||||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
|
||||||
padding: 32px;
|
|
||||||
}
|
|
||||||
.header {
|
|
||||||
display: flex;
|
|
||||||
margin-bottom: 24px;
|
|
||||||
}
|
|
||||||
.header .logo img {
|
|
||||||
width: 32px;
|
|
||||||
height: 32px;
|
|
||||||
vertical-align: middle;
|
|
||||||
}
|
|
||||||
.header h1 {
|
|
||||||
font-size: 1.5rem;
|
|
||||||
font-weight: bold;
|
|
||||||
display: inline-block;
|
|
||||||
vertical-align: middle;
|
|
||||||
margin-left: 8px;
|
|
||||||
}
|
|
||||||
.warning {
|
|
||||||
background-color: #ffd966;
|
|
||||||
color: #7f6000;
|
|
||||||
padding: 4px 12px;
|
|
||||||
border-radius: 50px;
|
|
||||||
font-size: 0.875rem;
|
|
||||||
margin: auto 0 auto auto;
|
|
||||||
}
|
|
||||||
.content {
|
|
||||||
background-color: #fafafa;
|
|
||||||
padding: 24px;
|
|
||||||
border-radius: 10px;
|
|
||||||
}
|
|
||||||
.content h2 {
|
|
||||||
font-size: 1.25rem;
|
|
||||||
font-weight: bold;
|
|
||||||
margin-bottom: 16px;
|
|
||||||
}
|
|
||||||
.grid {
|
|
||||||
width: 100%;
|
|
||||||
margin-bottom: 16px;
|
|
||||||
}
|
|
||||||
.grid td {
|
|
||||||
width: 50%;
|
|
||||||
padding-bottom: 8px;
|
|
||||||
vertical-align: top;
|
|
||||||
}
|
|
||||||
.label {
|
|
||||||
color: #888;
|
|
||||||
font-size: 0.875rem;
|
|
||||||
}
|
|
||||||
.message {
|
|
||||||
font-size: 1rem;
|
|
||||||
line-height: 1.5;
|
|
||||||
margin-top: 16px;
|
|
||||||
}
|
|
||||||
.button {
|
|
||||||
background-color: #000000;
|
|
||||||
color: #ffffff;
|
|
||||||
padding: 0.7rem 1.5rem;
|
|
||||||
text-decoration: none;
|
|
||||||
border-radius: 4px;
|
|
||||||
font-size: 1rem;
|
|
||||||
font-weight: 500;
|
|
||||||
display: inline-block;
|
|
||||||
margin-top: 24px;
|
|
||||||
}
|
|
||||||
.button-container {
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
{{ end }}
|
|
||||||
@@ -1,40 +1,5 @@
|
|||||||
{{ define "base" }}
|
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><!--$--><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px">
|
||||||
<div class="header">
|
<img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">New Sign-In Detected</h1></td><td align="right" data-id="__react-email-column">
|
||||||
<div class="logo">
|
<p style="font-size:12px;line-height:24px;background-color:#ffd966;color:#7f6000;padding:1px 12px;border-radius:50px;display:inline-block;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Warning</p></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Your <!-- -->{{.AppName}}<!-- --> account was recently accessed from a new IP address or browser. If you recognize this activity, no further action is required.</p><h4 style="font-size:1rem;font-weight:bold;margin:30px 0 10px 0">Details</h4><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:225px"><p style="font-size:12px;line-height:24px;margin:0;color:gray;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Approximate Location</p>
|
||||||
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
|
<p style="font-size:14px;line-height:24px;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{if and .Data.City .Data.Country}}{{.Data.City}}, {{.Data.Country}}{{else if .Data.Country}}{{.Data.Country}}{{else}}Unknown{{end}}</p></td><td data-id="__react-email-column" style="width:225px"><p style="font-size:12px;line-height:24px;margin:0;color:gray;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">IP Address</p><p style="font-size:14px;line-height:24px;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.Data.IPAddress}}</p></td></tr></tbody></table><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-top:10px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:225px"><p style="font-size:12px;line-height:24px;margin:0;color:gray;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Device</p>
|
||||||
<h1>{{ .AppName }}</h1>
|
<p style="font-size:14px;line-height:24px;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.Data.Device}}</p></td><td data-id="__react-email-column" style="width:225px"><p style="font-size:12px;line-height:24px;margin:0;color:gray;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">Sign-In Time</p><p style="font-size:14px;line-height:24px;margin:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.Data.DateTime.Format "January 2, 2006 at 3:04 PM MST"}}</p></td></tr></tbody></table></div></td></tr></tbody></table><!--7--><!--/$--></body></html>{{end}}
|
||||||
</div>
|
|
||||||
<div class="warning">Warning</div>
|
|
||||||
</div>
|
|
||||||
<div class="content">
|
|
||||||
<h2>New Sign-In Detected</h2>
|
|
||||||
<table class="grid">
|
|
||||||
<tr>
|
|
||||||
{{ if and .Data.City .Data.Country }}
|
|
||||||
<td>
|
|
||||||
<p class="label">Approximate Location</p>
|
|
||||||
<p>{{ .Data.City }}, {{ .Data.Country }}</p>
|
|
||||||
</td>
|
|
||||||
{{ end }}
|
|
||||||
<td>
|
|
||||||
<p class="label">IP Address</p>
|
|
||||||
<p>{{ .Data.IPAddress }}</p>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<p class="label">Device</p>
|
|
||||||
<p>{{ .Data.Device }}</p>
|
|
||||||
</td>
|
|
||||||
<td>
|
|
||||||
<p class="label">Sign-In Time</p>
|
|
||||||
<p>{{ .Data.DateTime.Format "2006-01-02 15:04:05 UTC" }}</p>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
<p class="message">
|
|
||||||
This sign-in was detected from a new device or location. If you recognize this activity, you can
|
|
||||||
safely ignore this message. If not, please review your account and security settings.
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
{{ end -}}
|
|
||||||
@@ -1,15 +1,28 @@
|
|||||||
{{ define "base" -}}
|
{{define "root"}}{{.AppName}}
|
||||||
New Sign-In Detected
|
|
||||||
====================
|
|
||||||
|
|
||||||
{{ if and .Data.City .Data.Country }}
|
|
||||||
Approximate Location: {{ .Data.City }}, {{ .Data.Country }}
|
|
||||||
{{ end }}
|
|
||||||
IP Address: {{ .Data.IPAddress }}
|
|
||||||
Device: {{ .Data.Device }}
|
|
||||||
Time: {{ .Data.DateTime.Format "2006-01-02 15:04:05 UTC"}}
|
|
||||||
|
|
||||||
This sign-in was detected from a new device or location. If you recognize
|
NEW SIGN-IN DETECTED
|
||||||
this activity, you can safely ignore this message. If not, please review
|
|
||||||
your account and security settings.
|
Warning
|
||||||
{{ end -}}
|
|
||||||
|
Your {{.AppName}} account was recently accessed from a new IP address or
|
||||||
|
browser. If you recognize this activity, no further action is required.
|
||||||
|
|
||||||
|
DETAILS
|
||||||
|
|
||||||
|
Approximate Location
|
||||||
|
|
||||||
|
{{if and .Data.City .Data.Country}}{{.Data.City}}, {{.Data.Country}}{{else if
|
||||||
|
.Data.Country}}{{.Data.Country}}{{else}}Unknown{{end}}
|
||||||
|
|
||||||
|
IP Address
|
||||||
|
|
||||||
|
{{.Data.IPAddress}}
|
||||||
|
|
||||||
|
Device
|
||||||
|
|
||||||
|
{{.Data.Device}}
|
||||||
|
|
||||||
|
Sign-In Time
|
||||||
|
|
||||||
|
{{.Data.DateTime.Format "January 2, 2006 at 3:04 PM MST"}}{{end}}
|
||||||
@@ -1,17 +1,4 @@
|
|||||||
{{ define "base" }}
|
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><!--$--><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px">
|
||||||
<div class="header">
|
<img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">Your Login Code</h1></td><td align="right" data-id="__react-email-column"></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Click the button below to sign in to <!-- -->
|
||||||
<div class="logo">
|
{{.AppName}}<!-- --> with a login code.<br/>Or visit<!-- --> <a href="{{.Data.LoginLink}}" style="color:#000;text-decoration-line:none;text-decoration:underline;font-family:Arial, sans-serif" target="_blank">{{.Data.LoginLink}}</a> <!-- -->and enter the code <strong>{{.Data.Code}}</strong>.<br/><br/>This code expires in <!-- -->{{.Data.ExpirationString}}<!-- -->.</p><div style="text-align:center"><a href="{{.Data.LoginLinkWithCode}}" style="line-height:100%;text-decoration:none;display:inline-block;max-width:100%;mso-padding-alt:0px;background-color:#000000;color:#ffffff;padding:12px 24px;border-radius:4px;font-size:15px;font-weight:500;cursor:pointer;margin-top:10px;padding-top:12px;padding-right:24px;padding-bottom:12px;padding-left:24px" target="_blank"><span><!--[if mso]><i style="mso-font-width:400%;mso-text-raise:18" hidden>   </i><![endif]--></span><span style="max-width:100%;display:inline-block;line-height:120%;mso-padding-alt:0px;mso-text-raise:9px">
|
||||||
<img src="{{ .LogoURL }}" alt="{{ .AppName }}" width="32" height="32" style="width: 32px; height: 32px; max-width: 32px;"/>
|
Sign In</span><span><!--[if mso]><i style="mso-font-width:400%" hidden>   ​</i><![endif]--></span></a></div></div></td></tr></tbody></table><!--7--><!--/$--></body></html>{{end}}
|
||||||
<h1>{{ .AppName }}</h1>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="content">
|
|
||||||
<h2>Login Code</h2>
|
|
||||||
<p class="message">
|
|
||||||
Click the button below to sign in to {{ .AppName }} with a login code.</br>Or visit <a href="{{ .Data.LoginLink }}">{{ .Data.LoginLink }}</a> and enter the code <strong>{{ .Data.Code }}</strong>.</br></br>This code expires in {{.Data.ExpirationString}}.
|
|
||||||
</p>
|
|
||||||
<div class="button-container">
|
|
||||||
<a class="button" href="{{ .Data.LoginLinkWithCode }}" class="button">Sign In</a>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{ end -}}
|
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
{{ define "base" -}}
|
{{define "root"}}{{.AppName}}
|
||||||
Login Code
|
|
||||||
====================
|
|
||||||
|
|
||||||
Click the link below to sign in to {{ .AppName }} with a login code. This code expires in {{.Data.ExpirationString}}.
|
|
||||||
|
|
||||||
{{ .Data.LoginLinkWithCode }}
|
YOUR LOGIN CODE
|
||||||
|
|
||||||
Or visit {{ .Data.LoginLink }} and enter the the code "{{ .Data.Code }}".
|
Click the button below to sign in to {{.AppName}} with a login code.
|
||||||
{{ end -}}
|
Or visit {{.Data.LoginLink}} {{.Data.LoginLink}} and enter the code
|
||||||
|
{{.Data.Code}}.
|
||||||
|
|
||||||
|
This code expires in {{.Data.ExpirationString}}.
|
||||||
|
|
||||||
|
Sign In {{.Data.LoginLinkWithCode}}{{end}}
|
||||||
@@ -1,11 +1,3 @@
|
|||||||
{{ define "base" -}}
|
{{define "root"}}<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html dir="ltr" lang="en"><head><link rel="preload" as="image" href="{{.LogoURL}}"/><meta content="text/html; charset=UTF-8" http-equiv="Content-Type"/><meta name="x-apple-disable-message-reformatting"/></head><body style="padding:50px;background-color:#FBFBFB;font-family:Arial, sans-serif"><!--$--><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="max-width:37.5em;width:500px;margin:0 auto"><tbody><tr style="width:100%"><td><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody><tr><td><table align="left" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation" style="margin-bottom:16px"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column" style="width:50px">
|
||||||
<div class="header">
|
<img alt="{{.AppName}}" height="32" src="{{.LogoURL}}" style="display:block;outline:none;border:none;text-decoration:none;width:32px;height:32px;vertical-align:middle" width="32"/></td><td data-id="__react-email-column"><p style="font-size:23px;line-height:24px;font-weight:bold;margin:0;padding:0;margin-top:0;margin-bottom:0;margin-left:0;margin-right:0">{{.AppName}}</p></td></tr></tbody></table></td></tr></tbody></table><div style="background-color:white;padding:24px;border-radius:10px;box-shadow:0 1px 4px 0px rgba(0, 0, 0, 0.1)"><table align="center" width="100%" border="0" cellPadding="0" cellSpacing="0" role="presentation"><tbody style="width:100%"><tr style="width:100%"><td data-id="__react-email-column"><h1 style="font-size:20px;font-weight:bold;margin:0">Test Email</h1></td><td align="right" data-id="__react-email-column"></td></tr></tbody></table><p style="font-size:14px;line-height:24px;margin-top:16px;margin-bottom:16px">Your email setup is working correctly!</p></div></td>
|
||||||
<div class="logo">
|
</tr></tbody></table><!--7--><!--/$--></body></html>{{end}}
|
||||||
<img src="{{ .LogoURL }}" alt="{{ .AppName }}"/>
|
|
||||||
<h1>{{ .AppName }}</h1>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div class="content">
|
|
||||||
<p>This is a test email.</p>
|
|
||||||
</div>
|
|
||||||
{{ end -}}
|
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
{{ define "base" -}}
|
{{define "root"}}{{.AppName}}
|
||||||
This is a test email.
|
|
||||||
{{ end -}}
|
|
||||||
|
TEST EMAIL
|
||||||
|
|
||||||
|
Your email setup is working correctly!{{end}}
|
||||||
@@ -4,5 +4,5 @@ import "embed"
|
|||||||
|
|
||||||
// Embedded file systems for the project
|
// Embedded file systems for the project
|
||||||
|
|
||||||
//go:embed email-templates images migrations fonts aaguids.json
|
//go:embed email-templates/*.tmpl images migrations fonts aaguids.json
|
||||||
var FS embed.FS
|
var FS embed.FS
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 3.7 MiB |
BIN
backend/resources/images/background.webp
Normal file
BIN
backend/resources/images/background.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 291 KiB |
@@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE users DROP COLUMN display_name;
|
||||||
|
|
||||||
|
ALTER TABLE users ALTER COLUMN username TYPE TEXT;
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
ALTER TABLE users ADD COLUMN display_name TEXT;
|
||||||
|
UPDATE users SET display_name = trim(coalesce(first_name,'') || ' ' || coalesce(last_name,''));
|
||||||
|
ALTER TABLE users ALTER COLUMN display_name SET NOT NULL;
|
||||||
|
|
||||||
|
CREATE EXTENSION IF NOT EXISTS citext;
|
||||||
|
ALTER TABLE users ALTER COLUMN username TYPE CITEXT COLLATE "C";
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- No-op because email was optional before the migration
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE users ALTER COLUMN email DROP NOT NULL;
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
BEGIN;
|
||||||
|
ALTER TABLE users DROP COLUMN display_name;
|
||||||
|
COMMIT;
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
PRAGMA foreign_keys = OFF;
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
CREATE TABLE users_new
|
||||||
|
(
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
created_at DATETIME,
|
||||||
|
username TEXT NOT NULL COLLATE NOCASE UNIQUE,
|
||||||
|
email TEXT NOT NULL UNIQUE,
|
||||||
|
first_name TEXT,
|
||||||
|
last_name TEXT NOT NULL,
|
||||||
|
display_name TEXT NOT NULL,
|
||||||
|
is_admin NUMERIC NOT NULL DEFAULT FALSE,
|
||||||
|
ldap_id TEXT,
|
||||||
|
locale TEXT,
|
||||||
|
disabled NUMERIC NOT NULL DEFAULT FALSE
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO users_new (id, created_at, username, email, first_name, last_name, display_name, is_admin, ldap_id, locale,
|
||||||
|
disabled)
|
||||||
|
SELECT id,
|
||||||
|
created_at,
|
||||||
|
username,
|
||||||
|
email,
|
||||||
|
first_name,
|
||||||
|
COALESCE(last_name, ''),
|
||||||
|
TRIM(COALESCE(first_name, '') || ' ' || COALESCE(last_name, '')),
|
||||||
|
is_admin,
|
||||||
|
ldap_id,
|
||||||
|
locale,
|
||||||
|
disabled
|
||||||
|
FROM users;
|
||||||
|
|
||||||
|
DROP TABLE users;
|
||||||
|
|
||||||
|
ALTER TABLE users_new
|
||||||
|
RENAME TO users;
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX users_ldap_id ON users (ldap_id);
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys = ON;
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
-- No-op because email was optional before the migration
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
PRAGMA foreign_keys = OFF;
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
CREATE TABLE users_new
|
||||||
|
(
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
created_at DATETIME,
|
||||||
|
username TEXT NOT NULL COLLATE NOCASE UNIQUE,
|
||||||
|
email TEXT UNIQUE,
|
||||||
|
first_name TEXT,
|
||||||
|
last_name TEXT NOT NULL,
|
||||||
|
display_name TEXT NOT NULL,
|
||||||
|
is_admin NUMERIC NOT NULL DEFAULT FALSE,
|
||||||
|
ldap_id TEXT,
|
||||||
|
locale TEXT,
|
||||||
|
disabled NUMERIC NOT NULL DEFAULT FALSE
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO users_new (id, created_at, username, email, first_name, last_name, display_name, is_admin, ldap_id, locale,
|
||||||
|
disabled)
|
||||||
|
SELECT id,
|
||||||
|
created_at,
|
||||||
|
username,
|
||||||
|
email,
|
||||||
|
first_name,
|
||||||
|
last_name,
|
||||||
|
display_name,
|
||||||
|
is_admin,
|
||||||
|
ldap_id,
|
||||||
|
locale,
|
||||||
|
disabled
|
||||||
|
FROM users;
|
||||||
|
|
||||||
|
DROP TABLE users;
|
||||||
|
|
||||||
|
ALTER TABLE users_new
|
||||||
|
RENAME TO users;
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
PRAGMA foreign_keys = ON;
|
||||||
45
cliff.toml
Normal file
45
cliff.toml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# git-cliff ~ configuration file
|
||||||
|
# https://git-cliff.org/docs/configuration
|
||||||
|
|
||||||
|
[remote.github]
|
||||||
|
owner = "pocket-id"
|
||||||
|
repo = "pocket-id"
|
||||||
|
|
||||||
|
[git]
|
||||||
|
conventional_commits = true
|
||||||
|
filter_unconventional = true
|
||||||
|
commit_preprocessors = [{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "" }]
|
||||||
|
commit_parsers = [
|
||||||
|
{ message = "^feat", group = "Features" },
|
||||||
|
{ message = "^fix", group = "Bug Fixes" },
|
||||||
|
{ message = "^docs", group = "Documentation" },
|
||||||
|
{ message = "^perf", group = "Performance Improvements" },
|
||||||
|
{ message = "^release", skip = true },
|
||||||
|
{ message = "update translations via Crowdin", skip = true },
|
||||||
|
{ message = ".*", group = "Other", default_scope = "other"},
|
||||||
|
]
|
||||||
|
filter_commits = false
|
||||||
|
|
||||||
|
[changelog]
|
||||||
|
trim = true
|
||||||
|
body = """
|
||||||
|
## {{ version | default(value="Unknown Version") }}
|
||||||
|
{% for group, commits in commits | group_by(attribute="group") %}
|
||||||
|
### {{ group | title }}
|
||||||
|
{% for commit in commits %}
|
||||||
|
- {{ commit.message | trim }} \
|
||||||
|
{%- if commit.remote.pr_number %} ([#{{ commit.remote.pr_number }}]({{ self::remote_url() }}/pull/{{ commit.remote.pr_number }}) by @{{ commit.remote.username | default(value=commit.author.name) }}){%- else %} ([{{ commit.id | truncate(length=7, end="") }}]({{ self::remote_url() }}/commit/{{ commit.id }}) by @{{ commit.remote.username | default(value=commit.author.name) }}){%- endif -%}
|
||||||
|
{% endfor %}
|
||||||
|
{% endfor %}
|
||||||
|
{% if version -%}
|
||||||
|
{% if previous.version -%}
|
||||||
|
**Full Changelog**: {{ self::remote_url() }}/compare/{{ previous.version }}...{{ version }}
|
||||||
|
{% endif %}
|
||||||
|
{% else -%}
|
||||||
|
{% raw %}\n{% endraw %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{%- macro remote_url() -%}
|
||||||
|
https://github.com/{{ remote.github.owner }}/{{ remote.github.repo }}
|
||||||
|
{%- endmacro -%}
|
||||||
|
"""
|
||||||
115
email-templates/build.ts
Normal file
115
email-templates/build.ts
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
import { render } from "@react-email/components";
|
||||||
|
import * as fs from "node:fs";
|
||||||
|
import * as path from "node:path";
|
||||||
|
|
||||||
|
const outputDir = "../backend/resources/email-templates";
|
||||||
|
|
||||||
|
if (!fs.existsSync(outputDir)) {
|
||||||
|
fs.mkdirSync(outputDir, { recursive: true });
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTemplateName(filename: string): string {
|
||||||
|
return filename.replace(".tsx", "");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tag-aware wrapping:
|
||||||
|
* - Prefer breaking immediately after the last '>' within maxLen.
|
||||||
|
* - Never break at spaces.
|
||||||
|
* - If no '>' exists in the window, hard-break at maxLen.
|
||||||
|
*/
|
||||||
|
function tagAwareWrap(input: string, maxLen: number): string {
|
||||||
|
const out: string[] = [];
|
||||||
|
|
||||||
|
for (const originalLine of input.split(/\r?\n/)) {
|
||||||
|
let line = originalLine;
|
||||||
|
while (line.length > maxLen) {
|
||||||
|
let breakPos = line.lastIndexOf(">", maxLen);
|
||||||
|
|
||||||
|
// If '>' happens to be exactly at maxLen, break after it
|
||||||
|
if (breakPos === maxLen) breakPos = maxLen;
|
||||||
|
|
||||||
|
// If we found a '>' before the limit, break right after it
|
||||||
|
if (breakPos > -1 && breakPos < maxLen) {
|
||||||
|
out.push(line.slice(0, breakPos + 1));
|
||||||
|
line = line.slice(breakPos + 1);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// No suitable tag end found—hard break
|
||||||
|
out.push(line.slice(0, maxLen));
|
||||||
|
line = line.slice(maxLen);
|
||||||
|
}
|
||||||
|
out.push(line);
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.join("\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
async function buildTemplateFile(
|
||||||
|
Component: any,
|
||||||
|
templateName: string,
|
||||||
|
isPlainText: boolean
|
||||||
|
) {
|
||||||
|
const rendered = await render(Component(Component.TemplateProps), {
|
||||||
|
plainText: isPlainText,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Normalize quotes
|
||||||
|
const normalized = rendered.replace(/"/g, '"');
|
||||||
|
|
||||||
|
// Enforce line length: prefer tag boundaries, never spaces
|
||||||
|
const maxLen = isPlainText ? 78 : 998; // RFC-safe
|
||||||
|
const safe = tagAwareWrap(normalized, maxLen);
|
||||||
|
|
||||||
|
const goTemplate = `{{define "root"}}${safe}{{end}}`;
|
||||||
|
const suffix = isPlainText ? "_text.tmpl" : "_html.tmpl";
|
||||||
|
const templatePath = path.join(outputDir, `${templateName}${suffix}`);
|
||||||
|
|
||||||
|
fs.writeFileSync(templatePath, goTemplate);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function discoverAndBuildTemplates() {
|
||||||
|
console.log("Discovering and building email templates...");
|
||||||
|
|
||||||
|
const emailsDir = "./emails";
|
||||||
|
const files = fs.readdirSync(emailsDir);
|
||||||
|
|
||||||
|
for (const file of files) {
|
||||||
|
if (!file.endsWith(".tsx")) continue;
|
||||||
|
|
||||||
|
const templateName = getTemplateName(file);
|
||||||
|
const modulePath = `./${emailsDir}/${file}`;
|
||||||
|
|
||||||
|
console.log(`Building ${templateName}...`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const module = await import(modulePath);
|
||||||
|
const Component = module.default || module[Object.keys(module)[0]];
|
||||||
|
|
||||||
|
if (!Component) {
|
||||||
|
console.error(`✗ No component found in ${file}`);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!Component.TemplateProps) {
|
||||||
|
console.error(`✗ No TemplateProps found in ${file}`);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
await buildTemplateFile(Component, templateName, false); // HTML
|
||||||
|
await buildTemplateFile(Component, templateName, true); // Text
|
||||||
|
|
||||||
|
console.log(`✓ Built ${templateName}`);
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`✗ Error building ${templateName}:`, error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
await discoverAndBuildTemplates();
|
||||||
|
console.log("All templates built successfully!");
|
||||||
|
}
|
||||||
|
|
||||||
|
main().catch(console.error);
|
||||||
81
email-templates/components/base-template.tsx
Normal file
81
email-templates/components/base-template.tsx
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
import {
|
||||||
|
Body,
|
||||||
|
Column,
|
||||||
|
Container,
|
||||||
|
Head,
|
||||||
|
Html,
|
||||||
|
Img,
|
||||||
|
Row,
|
||||||
|
Section,
|
||||||
|
Text,
|
||||||
|
} from "@react-email/components";
|
||||||
|
|
||||||
|
interface BaseTemplateProps {
|
||||||
|
logoURL?: string;
|
||||||
|
appName: string;
|
||||||
|
children: React.ReactNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const BaseTemplate = ({
|
||||||
|
logoURL,
|
||||||
|
appName,
|
||||||
|
children,
|
||||||
|
}: BaseTemplateProps) => {
|
||||||
|
return (
|
||||||
|
<Html>
|
||||||
|
<Head />
|
||||||
|
<Body style={mainStyle}>
|
||||||
|
<Container style={{ width: "500px", margin: "0 auto" }}>
|
||||||
|
<Section>
|
||||||
|
<Row
|
||||||
|
align="left"
|
||||||
|
style={{
|
||||||
|
marginBottom: "16px",
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Column style={{ width: "50px" }}>
|
||||||
|
<Img
|
||||||
|
src={logoURL}
|
||||||
|
width="32"
|
||||||
|
height="32"
|
||||||
|
alt={appName}
|
||||||
|
style={logoStyle}
|
||||||
|
/>
|
||||||
|
</Column>
|
||||||
|
<Column>
|
||||||
|
<Text style={titleStyle}>{appName}</Text>
|
||||||
|
</Column>
|
||||||
|
</Row>
|
||||||
|
</Section>
|
||||||
|
<div style={content}>{children}</div>
|
||||||
|
</Container>
|
||||||
|
</Body>
|
||||||
|
</Html>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const mainStyle = {
|
||||||
|
padding: "50px",
|
||||||
|
backgroundColor: "#FBFBFB",
|
||||||
|
fontFamily: "Arial, sans-serif",
|
||||||
|
};
|
||||||
|
|
||||||
|
const logoStyle = {
|
||||||
|
width: "32px",
|
||||||
|
height: "32px",
|
||||||
|
verticalAlign: "middle",
|
||||||
|
};
|
||||||
|
|
||||||
|
const titleStyle = {
|
||||||
|
fontSize: "23px",
|
||||||
|
fontWeight: "bold",
|
||||||
|
margin: "0",
|
||||||
|
padding: "0",
|
||||||
|
};
|
||||||
|
|
||||||
|
const content = {
|
||||||
|
backgroundColor: "white",
|
||||||
|
padding: "24px",
|
||||||
|
borderRadius: "10px",
|
||||||
|
boxShadow: "0 1px 4px 0px rgba(0, 0, 0, 0.1)",
|
||||||
|
};
|
||||||
33
email-templates/components/button.tsx
Normal file
33
email-templates/components/button.tsx
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import { Button as EmailButton } from "@react-email/components";
|
||||||
|
|
||||||
|
interface ButtonProps {
|
||||||
|
href: string;
|
||||||
|
children: React.ReactNode;
|
||||||
|
style?: React.CSSProperties;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const Button = ({ href, children, style = {} }: ButtonProps) => {
|
||||||
|
const buttonStyle = {
|
||||||
|
backgroundColor: "#000000",
|
||||||
|
color: "#ffffff",
|
||||||
|
padding: "12px 24px",
|
||||||
|
borderRadius: "4px",
|
||||||
|
fontSize: "15px",
|
||||||
|
fontWeight: "500",
|
||||||
|
cursor: "pointer",
|
||||||
|
marginTop: "10px",
|
||||||
|
...style,
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div style={buttonContainer}>
|
||||||
|
<EmailButton style={buttonStyle} href={href}>
|
||||||
|
{children}
|
||||||
|
</EmailButton>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const buttonContainer = {
|
||||||
|
textAlign: "center" as const,
|
||||||
|
};
|
||||||
38
email-templates/components/card-header.tsx
Normal file
38
email-templates/components/card-header.tsx
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import { Column, Heading, Row, Text } from "@react-email/components";
|
||||||
|
|
||||||
|
export default function CardHeader({
|
||||||
|
title,
|
||||||
|
warning,
|
||||||
|
}: {
|
||||||
|
title: string;
|
||||||
|
warning?: boolean;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Row>
|
||||||
|
<Column>
|
||||||
|
<Heading as="h1" style={titleStyle}>
|
||||||
|
{title}
|
||||||
|
</Heading>
|
||||||
|
</Column>
|
||||||
|
<Column align="right">
|
||||||
|
{warning && <Text style={warningStyle}>Warning</Text>}
|
||||||
|
</Column>
|
||||||
|
</Row>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const titleStyle = {
|
||||||
|
fontSize: "20px",
|
||||||
|
fontWeight: "bold" as const,
|
||||||
|
margin: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const warningStyle = {
|
||||||
|
backgroundColor: "#ffd966",
|
||||||
|
color: "#7f6000",
|
||||||
|
padding: "1px 12px",
|
||||||
|
borderRadius: "50px",
|
||||||
|
fontSize: "12px",
|
||||||
|
display: "inline-block",
|
||||||
|
margin: 0,
|
||||||
|
};
|
||||||
55
email-templates/emails/api-key-expiring-soon.tsx
Normal file
55
email-templates/emails/api-key-expiring-soon.tsx
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import { Text } from "@react-email/components";
|
||||||
|
import { BaseTemplate } from "../components/base-template";
|
||||||
|
import CardHeader from "../components/card-header";
|
||||||
|
import { sharedPreviewProps, sharedTemplateProps } from "../props";
|
||||||
|
|
||||||
|
interface ApiKeyExpiringData {
|
||||||
|
name: string;
|
||||||
|
apiKeyName: string;
|
||||||
|
expiresAt: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ApiKeyExpiringEmailProps {
|
||||||
|
logoURL: string;
|
||||||
|
appName: string;
|
||||||
|
data: ApiKeyExpiringData;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const ApiKeyExpiringEmail = ({
|
||||||
|
logoURL,
|
||||||
|
appName,
|
||||||
|
data,
|
||||||
|
}: ApiKeyExpiringEmailProps) => (
|
||||||
|
<BaseTemplate logoURL={logoURL} appName={appName}>
|
||||||
|
<CardHeader title="API Key Expiring Soon" warning />
|
||||||
|
<Text>
|
||||||
|
Hello {data.name}, <br />
|
||||||
|
This is a reminder that your API key <strong>
|
||||||
|
{data.apiKeyName}
|
||||||
|
</strong>{" "}
|
||||||
|
will expire on <strong>{data.expiresAt}</strong>.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<Text>Please generate a new API key if you need continued access.</Text>
|
||||||
|
</BaseTemplate>
|
||||||
|
);
|
||||||
|
|
||||||
|
export default ApiKeyExpiringEmail;
|
||||||
|
|
||||||
|
ApiKeyExpiringEmail.TemplateProps = {
|
||||||
|
...sharedTemplateProps,
|
||||||
|
data: {
|
||||||
|
name: "{{.Data.Name}}",
|
||||||
|
apiKeyName: "{{.Data.APIKeyName}}",
|
||||||
|
expiresAt: '{{.Data.ExpiresAt.Format "2006-01-02 15:04:05 MST"}}',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
ApiKeyExpiringEmail.PreviewProps = {
|
||||||
|
...sharedPreviewProps,
|
||||||
|
data: {
|
||||||
|
name: "Elias Schneider",
|
||||||
|
apiKeyName: "My API Key",
|
||||||
|
expiresAt: "September 30, 2024",
|
||||||
|
},
|
||||||
|
};
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user