mirror of
https://github.com/pocket-id/pocket-id.git
synced 2025-12-06 05:12:57 +03:00
191 lines
4.6 KiB
Go
191 lines
4.6 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
awscfg "github.com/aws/aws-sdk-go-v2/config"
|
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
|
"github.com/aws/aws-sdk-go-v2/service/s3"
|
|
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
|
|
"github.com/aws/smithy-go"
|
|
)
|
|
|
|
// S3Config holds the settings NewS3Storage uses to build an S3-backed
// FileStorage.
type S3Config struct {
	// Bucket names the bucket every object is stored in. Region is handed to
	// the AWS configuration loader. Endpoint, when non-empty, overrides the
	// S3 client's base endpoint.
	Bucket, Region, Endpoint string

	// AccessKeyID and SecretAccessKey are the static credentials for the
	// S3 client.
	AccessKeyID, SecretAccessKey string

	// ForcePathStyle switches the client to path-style addressing.
	// DisableDefaultIntegrityChecks limits request checksum calculation and
	// response checksum validation to operations that require them.
	ForcePathStyle, DisableDefaultIntegrityChecks bool

	// Root is an optional key prefix. Surrounding slashes are trimmed and
	// every object key is placed beneath it.
	Root string
}
|
|
|
|
type s3Storage struct {
|
|
client *s3.Client
|
|
bucket string
|
|
prefix string
|
|
}
|
|
|
|
func NewS3Storage(ctx context.Context, cfg S3Config) (FileStorage, error) {
|
|
creds := credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, "")
|
|
awsCfg, err := awscfg.LoadDefaultConfig(ctx, awscfg.WithRegion(cfg.Region), awscfg.WithCredentialsProvider(creds))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load AWS configuration: %w", err)
|
|
}
|
|
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
|
if cfg.Endpoint != "" {
|
|
o.BaseEndpoint = aws.String(cfg.Endpoint)
|
|
}
|
|
o.UsePathStyle = cfg.ForcePathStyle
|
|
if cfg.DisableDefaultIntegrityChecks {
|
|
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
|
o.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
|
|
}
|
|
})
|
|
|
|
return &s3Storage{
|
|
client: client,
|
|
bucket: cfg.Bucket,
|
|
prefix: strings.Trim(cfg.Root, "/"),
|
|
}, nil
|
|
}
|
|
|
|
func (s *s3Storage) Type() string {
|
|
return TypeS3
|
|
}
|
|
|
|
func (s *s3Storage) Save(ctx context.Context, path string, data io.Reader) error {
|
|
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(s.buildObjectKey(path)),
|
|
Body: data,
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *s3Storage) Open(ctx context.Context, path string) (io.ReadCloser, int64, error) {
|
|
resp, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(s.buildObjectKey(path)),
|
|
})
|
|
if err != nil {
|
|
if isS3NotFound(err) {
|
|
return nil, 0, fs.ErrNotExist
|
|
}
|
|
return nil, 0, err
|
|
}
|
|
return resp.Body, aws.ToInt64(resp.ContentLength), nil
|
|
}
|
|
|
|
func (s *s3Storage) Delete(ctx context.Context, path string) error {
|
|
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Key: aws.String(s.buildObjectKey(path)),
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (s *s3Storage) DeleteAll(ctx context.Context, path string) error {
|
|
|
|
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
|
Bucket: aws.String(s.bucket),
|
|
Prefix: aws.String(s.buildObjectKey(path)),
|
|
})
|
|
for paginator.HasMorePages() {
|
|
page, err := paginator.NextPage(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(page.Contents) == 0 {
|
|
continue
|
|
}
|
|
objects := make([]s3types.ObjectIdentifier, 0, len(page.Contents))
|
|
for _, obj := range page.Contents {
|
|
objects = append(objects, s3types.ObjectIdentifier{Key: obj.Key})
|
|
}
|
|
_, err = s.client.DeleteObjects(ctx, &s3.DeleteObjectsInput{
|
|
Bucket: aws.String(s.bucket),
|
|
Delete: &s3types.Delete{Objects: objects, Quiet: aws.Bool(true)},
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *s3Storage) List(ctx context.Context, path string) ([]ObjectInfo, error) {
|
|
paginator := s3.NewListObjectsV2Paginator(s.client, &s3.ListObjectsV2Input{
|
|
Bucket: aws.String(s.bucket),
|
|
Prefix: aws.String(s.buildObjectKey(path)),
|
|
})
|
|
var objects []ObjectInfo
|
|
for paginator.HasMorePages() {
|
|
page, err := paginator.NextPage(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for _, obj := range page.Contents {
|
|
if obj.Key == nil {
|
|
continue
|
|
}
|
|
objects = append(objects, ObjectInfo{
|
|
Path: aws.ToString(obj.Key),
|
|
Size: aws.ToInt64(obj.Size),
|
|
ModTime: aws.ToTime(obj.LastModified),
|
|
})
|
|
}
|
|
}
|
|
return objects, nil
|
|
}
|
|
|
|
func (s *s3Storage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
|
objects, err := s.List(ctx, root)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, obj := range objects {
|
|
if err := fn(obj); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *s3Storage) buildObjectKey(p string) string {
|
|
p = filepath.Clean(p)
|
|
p = filepath.ToSlash(p)
|
|
p = strings.Trim(p, "/")
|
|
|
|
if p == "" || p == "." {
|
|
return s.prefix
|
|
}
|
|
|
|
if s.prefix == "" {
|
|
return p
|
|
}
|
|
|
|
return s.prefix + "/" + p
|
|
}
|
|
|
|
func isS3NotFound(err error) bool {
|
|
var apiErr smithy.APIError
|
|
if errors.As(err, &apiErr) {
|
|
if apiErr.ErrorCode() == "NotFound" || apiErr.ErrorCode() == "NoSuchKey" {
|
|
return true
|
|
}
|
|
}
|
|
var missingKey *s3types.NoSuchKey
|
|
return errors.As(err, &missingKey)
|
|
}
|