Stage 4: lobby & social (matchmaking, friends, blocks, chat+nudge, invitations, profile, email, multi-player drop-out)
Engine: multi-player drop-out-and-continue with a per-game tile disposition (remove default / return), resigned seats skipped and excluded from the win, leaver rack never revealed; 2-player behaviour unchanged. New domains (service/store, no HTTP yet): internal/social (friend request/accept graph, per-user blocks, per-game chat with nudge as a message kind, content filter via mvdan.cc/xurls/v2 + leet/separator normaliser + phone heuristic) and internal/lobby (in-memory variant-keyed matchmaking pool, friend-game invitations invite->accept with lazy 7-day expiry). account gains profile editing and the email confirm-code flow (Mailer seam: SMTP or log mailer). Migration 00003_social.sql + regenerated jet. main wires the new services into the server (accessors for the Stage 6 handlers); robot substitution stays in Stage 5, REST/stream/push in Stage 6/8. Docs (PLAN, ARCHITECTURE, FUNCTIONAL+ru, TESTING, README) updated.
This commit is contained in:
@@ -0,0 +1,278 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"scrabble/backend/internal/postgres/jet/backend/model"
|
||||
"scrabble/backend/internal/postgres/jet/backend/table"
|
||||
)
|
||||
|
||||
const (
|
||||
// emailCodeTTL bounds how long an issued confirm-code stays valid.
|
||||
emailCodeTTL = 15 * time.Minute
|
||||
// emailCodeMaxAttempts caps wrong-code submissions before a code is dead.
|
||||
emailCodeMaxAttempts = 5
|
||||
)
|
||||
|
||||
// Errors returned by the email confirm-code flow.
|
||||
var (
|
||||
// ErrInvalidEmail is returned for an unparseable email address.
|
||||
ErrInvalidEmail = errors.New("account: invalid email address")
|
||||
// ErrEmailTaken is returned when the email is already confirmed by another
|
||||
// account; binding it would be a merge, which Stage 10 owns.
|
||||
ErrEmailTaken = errors.New("account: email already confirmed by another account")
|
||||
// ErrAlreadyConfirmed is returned when the email is already confirmed by the
|
||||
// requesting account.
|
||||
ErrAlreadyConfirmed = errors.New("account: email already confirmed for this account")
|
||||
// ErrNoPendingCode is returned when no live confirm-code exists to verify.
|
||||
ErrNoPendingCode = errors.New("account: no pending confirmation code")
|
||||
// ErrCodeExpired is returned when the confirm-code has passed its TTL.
|
||||
ErrCodeExpired = errors.New("account: confirmation code expired")
|
||||
// ErrTooManyAttempts is returned when the code is locked after too many tries.
|
||||
ErrTooManyAttempts = errors.New("account: too many confirmation attempts")
|
||||
// ErrCodeMismatch is returned when the submitted code does not match.
|
||||
ErrCodeMismatch = errors.New("account: confirmation code does not match")
|
||||
)
|
||||
|
||||
// EmailService runs the email confirm-code flow: it issues a 6-digit code over a
|
||||
// Mailer and verifies it, binding a confirmed email identity to the requesting
|
||||
// account. Only the SHA-256 hash of a code is stored (never the plaintext),
|
||||
// matching the session model. Binding an email already confirmed by a different
|
||||
// account is refused (ErrEmailTaken) — merging two accounts is Stage 10 — and
|
||||
// using an email as a login is Stage 6, which reuses this mechanism.
|
||||
type EmailService struct {
|
||||
store *Store
|
||||
mailer Mailer
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// NewEmailService constructs an EmailService over store, sending via mailer.
|
||||
func NewEmailService(store *Store, mailer Mailer) *EmailService {
|
||||
return &EmailService{store: store, mailer: mailer, now: func() time.Time { return time.Now().UTC() }}
|
||||
}
|
||||
|
||||
// RequestCode issues a fresh confirm-code for email to accountID and mails it,
|
||||
// replacing any prior pending code for the same account and address. It returns
|
||||
// ErrInvalidEmail, ErrEmailTaken or ErrAlreadyConfirmed without sending.
|
||||
func (s *EmailService) RequestCode(ctx context.Context, accountID uuid.UUID, email string) error {
|
||||
addr, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
owner, ok, err := s.store.confirmedEmailAccount(ctx, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
if owner == accountID {
|
||||
return ErrAlreadyConfirmed
|
||||
}
|
||||
return ErrEmailTaken
|
||||
}
|
||||
code, hash, err := generateCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.store.replacePendingConfirmation(ctx, accountID, addr, hash, s.now().Add(emailCodeTTL)); err != nil {
|
||||
return err
|
||||
}
|
||||
subject := "Your Scrabble confirmation code"
|
||||
body := fmt.Sprintf("Your confirmation code is %s. It expires in %d minutes.", code, int(emailCodeTTL/time.Minute))
|
||||
return s.mailer.Send(ctx, addr, subject, body)
|
||||
}
|
||||
|
||||
// ConfirmCode verifies code for accountID and email. On success it attaches a
|
||||
// confirmed email identity and returns the account. It returns ErrNoPendingCode,
|
||||
// ErrCodeExpired, ErrTooManyAttempts, ErrCodeMismatch (counting the attempt), or
|
||||
// ErrEmailTaken if the address was confirmed elsewhere in the meantime.
|
||||
func (s *EmailService) ConfirmCode(ctx context.Context, accountID uuid.UUID, email, code string) (Account, error) {
|
||||
addr, err := normalizeEmail(email)
|
||||
if err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
conf, err := s.store.latestPendingConfirmation(ctx, accountID, addr)
|
||||
if err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
if s.now().After(conf.expiresAt) {
|
||||
return Account{}, ErrCodeExpired
|
||||
}
|
||||
if conf.attempts >= emailCodeMaxAttempts {
|
||||
return Account{}, ErrTooManyAttempts
|
||||
}
|
||||
if hashCode(code) != conf.codeHash {
|
||||
if err := s.store.bumpConfirmationAttempts(ctx, conf.id); err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
return Account{}, ErrCodeMismatch
|
||||
}
|
||||
if err := s.store.confirmEmailIdentity(ctx, conf.id, accountID, addr, s.now()); err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
return s.store.GetByID(ctx, accountID)
|
||||
}
|
||||
|
||||
// emailConfirmation is a pending confirm-code row in domain form.
|
||||
type emailConfirmation struct {
|
||||
id uuid.UUID
|
||||
codeHash string
|
||||
expiresAt time.Time
|
||||
attempts int
|
||||
}
|
||||
|
||||
// confirmedEmailAccount returns the account that holds a confirmed email identity
|
||||
// for email and true, or (zero, false) when none does.
|
||||
func (s *Store) confirmedEmailAccount(ctx context.Context, email string) (uuid.UUID, bool, error) {
|
||||
stmt := postgres.SELECT(table.Identities.AccountID).
|
||||
FROM(table.Identities).
|
||||
WHERE(
|
||||
table.Identities.Kind.EQ(postgres.String(KindEmail)).
|
||||
AND(table.Identities.ExternalID.EQ(postgres.String(email))).
|
||||
AND(table.Identities.Confirmed.EQ(postgres.Bool(true))),
|
||||
).LIMIT(1)
|
||||
var row model.Identities
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return uuid.UUID{}, false, nil
|
||||
}
|
||||
return uuid.UUID{}, false, fmt.Errorf("account: confirmed email owner %s: %w", email, err)
|
||||
}
|
||||
return row.AccountID, true, nil
|
||||
}
|
||||
|
||||
// replacePendingConfirmation clears any pending code for (accountID, email) and
|
||||
// inserts a fresh one, inside one transaction.
|
||||
func (s *Store) replacePendingConfirmation(ctx context.Context, accountID uuid.UUID, email, codeHash string, expiresAt time.Time) error {
|
||||
id, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return fmt.Errorf("account: new confirmation id: %w", err)
|
||||
}
|
||||
return withTx(ctx, s.db, func(tx *sql.Tx) error {
|
||||
del := table.EmailConfirmations.DELETE().WHERE(
|
||||
table.EmailConfirmations.AccountID.EQ(postgres.UUID(accountID)).
|
||||
AND(table.EmailConfirmations.Email.EQ(postgres.String(email))).
|
||||
AND(table.EmailConfirmations.ConsumedAt.IS_NULL()),
|
||||
)
|
||||
if _, err := del.ExecContext(ctx, tx); err != nil {
|
||||
return fmt.Errorf("clear pending confirmations: %w", err)
|
||||
}
|
||||
ins := table.EmailConfirmations.INSERT(
|
||||
table.EmailConfirmations.ConfirmationID, table.EmailConfirmations.AccountID,
|
||||
table.EmailConfirmations.Email, table.EmailConfirmations.CodeHash, table.EmailConfirmations.ExpiresAt,
|
||||
).VALUES(id, accountID, email, codeHash, expiresAt)
|
||||
if _, err := ins.ExecContext(ctx, tx); err != nil {
|
||||
return fmt.Errorf("insert confirmation: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// latestPendingConfirmation loads the newest unconsumed confirm-code for
|
||||
// (accountID, email), or ErrNoPendingCode.
|
||||
func (s *Store) latestPendingConfirmation(ctx context.Context, accountID uuid.UUID, email string) (emailConfirmation, error) {
|
||||
stmt := postgres.SELECT(table.EmailConfirmations.AllColumns).
|
||||
FROM(table.EmailConfirmations).
|
||||
WHERE(
|
||||
table.EmailConfirmations.AccountID.EQ(postgres.UUID(accountID)).
|
||||
AND(table.EmailConfirmations.Email.EQ(postgres.String(email))).
|
||||
AND(table.EmailConfirmations.ConsumedAt.IS_NULL()),
|
||||
).ORDER_BY(table.EmailConfirmations.CreatedAt.DESC()).LIMIT(1)
|
||||
var row model.EmailConfirmations
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return emailConfirmation{}, ErrNoPendingCode
|
||||
}
|
||||
return emailConfirmation{}, fmt.Errorf("account: load confirmation: %w", err)
|
||||
}
|
||||
return emailConfirmation{
|
||||
id: row.ConfirmationID,
|
||||
codeHash: row.CodeHash,
|
||||
expiresAt: row.ExpiresAt,
|
||||
attempts: int(row.Attempts),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// bumpConfirmationAttempts increments a code's wrong-attempt counter by one.
|
||||
func (s *Store) bumpConfirmationAttempts(ctx context.Context, id uuid.UUID) error {
|
||||
stmt := table.EmailConfirmations.
|
||||
UPDATE(table.EmailConfirmations.Attempts).
|
||||
SET(table.EmailConfirmations.Attempts.ADD(postgres.Int(1))).
|
||||
WHERE(table.EmailConfirmations.ConfirmationID.EQ(postgres.UUID(id)))
|
||||
if _, err := stmt.ExecContext(ctx, s.db); err != nil {
|
||||
return fmt.Errorf("account: bump confirmation attempts: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// confirmEmailIdentity consumes the code and inserts a confirmed email identity,
|
||||
// inside one transaction. A unique-constraint violation means the address was
|
||||
// confirmed by another account first, surfaced as ErrEmailTaken.
|
||||
func (s *Store) confirmEmailIdentity(ctx context.Context, confirmationID, accountID uuid.UUID, email string, now time.Time) error {
|
||||
identityID, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return fmt.Errorf("account: new identity id: %w", err)
|
||||
}
|
||||
err = withTx(ctx, s.db, func(tx *sql.Tx) error {
|
||||
upd := table.EmailConfirmations.
|
||||
UPDATE(table.EmailConfirmations.ConsumedAt).
|
||||
SET(postgres.TimestampzT(now)).
|
||||
WHERE(table.EmailConfirmations.ConfirmationID.EQ(postgres.UUID(confirmationID)))
|
||||
if _, err := upd.ExecContext(ctx, tx); err != nil {
|
||||
return fmt.Errorf("consume confirmation: %w", err)
|
||||
}
|
||||
ins := table.Identities.INSERT(
|
||||
table.Identities.IdentityID, table.Identities.AccountID, table.Identities.Kind,
|
||||
table.Identities.ExternalID, table.Identities.Confirmed,
|
||||
).VALUES(identityID, accountID, KindEmail, email, true)
|
||||
if _, err := ins.ExecContext(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return ErrEmailTaken
|
||||
}
|
||||
return fmt.Errorf("account: confirm email identity: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// normalizeEmail parses and lower-cases an email address, or returns ErrInvalidEmail.
|
||||
func normalizeEmail(email string) (string, error) {
|
||||
addr, err := mail.ParseAddress(strings.TrimSpace(email))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %q", ErrInvalidEmail, email)
|
||||
}
|
||||
return strings.ToLower(addr.Address), nil
|
||||
}
|
||||
|
||||
// generateCode returns a random 6-digit code and its SHA-256 hex hash.
|
||||
func generateCode() (code, hash string, err error) {
|
||||
n, err := crand.Int(crand.Reader, big.NewInt(1_000_000))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("account: generate code: %w", err)
|
||||
}
|
||||
code = fmt.Sprintf("%06d", n.Int64())
|
||||
return code, hashCode(code), nil
|
||||
}
|
||||
|
||||
// hashCode returns the hex-encoded SHA-256 of a confirm-code.
|
||||
func hashCode(code string) string {
|
||||
sum := sha256.Sum256([]byte(code))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"lowercases", "User@Example.COM", "user@example.com", false},
|
||||
{"trims", " a@b.io ", "a@b.io", false},
|
||||
{"strips display name", "Jane Doe <jane@x.org>", "jane@x.org", false},
|
||||
{"empty", "", "", true},
|
||||
{"no at sign", "notanemail", "", true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got, err := normalizeEmail(tc.in)
|
||||
if tc.wantErr {
|
||||
if !errors.Is(err, ErrInvalidEmail) {
|
||||
t.Fatalf("err = %v, want ErrInvalidEmail", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.want {
|
||||
t.Errorf("got %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeFormat(t *testing.T) {
|
||||
sixDigits := regexp.MustCompile(`^\d{6}$`)
|
||||
for range 50 {
|
||||
code, hash, err := generateCode()
|
||||
if err != nil {
|
||||
t.Fatalf("generate: %v", err)
|
||||
}
|
||||
if !sixDigits.MatchString(code) {
|
||||
t.Fatalf("code %q is not exactly six digits", code)
|
||||
}
|
||||
if hash != hashCode(code) {
|
||||
t.Errorf("returned hash does not match hashCode(%q)", code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashCodeStable(t *testing.T) {
|
||||
if hashCode("123456") != hashCode("123456") {
|
||||
t.Fatal("hashCode is not deterministic")
|
||||
}
|
||||
if hashCode("123456") == hashCode("654321") {
|
||||
t.Fatal("distinct codes must not share a hash")
|
||||
}
|
||||
if got := len(hashCode("000000")); got != 64 {
|
||||
t.Errorf("hex SHA-256 length = %d, want 64", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/smtp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Mailer delivers a transactional email. It is the seam behind which the email
|
||||
// confirm-code flow sends codes, so the relay is swappable and unit tests use a
|
||||
// fixture (see docs/TESTING.md: no real network in tests). The context is offered
|
||||
// for cancellation; the standard-library SMTP implementation sends synchronously
|
||||
// and ignores it.
|
||||
type Mailer interface {
|
||||
Send(ctx context.Context, to, subject, body string) error
|
||||
}
|
||||
|
||||
// SMTPConfig configures the SMTP relay. An empty Host selects the LogMailer
|
||||
// instead, so a deployment without a relay still runs (the code lands in the log).
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port string
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
}
|
||||
|
||||
// SMTPMailer sends mail through an SMTP relay using the standard library. When a
|
||||
// username is set it authenticates with PLAIN; otherwise it relays unauthenticated.
|
||||
type SMTPMailer struct {
|
||||
cfg SMTPConfig
|
||||
}
|
||||
|
||||
// NewSMTPMailer constructs an SMTPMailer for cfg.
|
||||
func NewSMTPMailer(cfg SMTPConfig) SMTPMailer {
|
||||
return SMTPMailer{cfg: cfg}
|
||||
}
|
||||
|
||||
// Send delivers a plain-text UTF-8 message to to via the configured relay.
|
||||
func (m SMTPMailer) Send(_ context.Context, to, subject, body string) error {
|
||||
addr := net.JoinHostPort(m.cfg.Host, m.cfg.Port)
|
||||
var auth smtp.Auth
|
||||
if m.cfg.Username != "" {
|
||||
auth = smtp.PlainAuth("", m.cfg.Username, m.cfg.Password, m.cfg.Host)
|
||||
}
|
||||
if err := smtp.SendMail(addr, auth, m.cfg.From, []string{to}, message(m.cfg.From, to, subject, body)); err != nil {
|
||||
return fmt.Errorf("account: send mail to %s: %w", to, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// message renders a minimal RFC 5322 plain-text email.
|
||||
func message(from, to, subject, body string) []byte {
|
||||
return []byte("From: " + from + "\r\n" +
|
||||
"To: " + to + "\r\n" +
|
||||
"Subject: " + subject + "\r\n" +
|
||||
"MIME-Version: 1.0\r\n" +
|
||||
"Content-Type: text/plain; charset=UTF-8\r\n" +
|
||||
"\r\n" + body + "\r\n")
|
||||
}
|
||||
|
||||
// LogMailer logs the message instead of sending it. It is the default when no
|
||||
// SMTP relay is configured and is intended for development only: it logs the body,
|
||||
// which carries the confirm-code, so it must not be used in production.
|
||||
type LogMailer struct {
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
// NewLogMailer constructs a LogMailer that logs through log.
|
||||
func NewLogMailer(log *zap.Logger) LogMailer {
|
||||
return LogMailer{log: log}
|
||||
}
|
||||
|
||||
// Send logs the message at info level and reports success.
|
||||
func (m LogMailer) Send(_ context.Context, to, subject, body string) error {
|
||||
if m.log != nil {
|
||||
m.log.Info("email not sent (log mailer)",
|
||||
zap.String("to", to), zap.String("subject", subject), zap.String("body", body))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/go-jet/jet/v2/postgres"
|
||||
"github.com/go-jet/jet/v2/qrm"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"scrabble/backend/internal/postgres/jet/backend/model"
|
||||
"scrabble/backend/internal/postgres/jet/backend/table"
|
||||
)
|
||||
|
||||
// maxDisplayName caps a display name's length in runes.
|
||||
const maxDisplayName = 64
|
||||
|
||||
// ErrInvalidProfile is returned when a profile update carries an unacceptable
|
||||
// field (an unknown language, an invalid timezone, or an over-long display name).
|
||||
var ErrInvalidProfile = errors.New("account: invalid profile")
|
||||
|
||||
// ProfileUpdate is the full set of player-editable profile fields. UpdateProfile
|
||||
// overwrites every field, so callers send the complete desired profile. AwayStart
|
||||
// and AwayEnd carry only the hour and minute of the daily away window, in the
|
||||
// account's TimeZone.
|
||||
type ProfileUpdate struct {
|
||||
DisplayName string
|
||||
PreferredLanguage string // "en" or "ru"
|
||||
TimeZone string // an IANA location name
|
||||
AwayStart time.Time
|
||||
AwayEnd time.Time
|
||||
BlockChat bool
|
||||
BlockFriendRequests bool
|
||||
}
|
||||
|
||||
// UpdateProfile validates and overwrites the editable fields of the account, then
|
||||
// returns the stored row. It reports ErrInvalidProfile for a bad language,
|
||||
// timezone or display name and ErrNotFound when no account matches id.
|
||||
func (s *Store) UpdateProfile(ctx context.Context, id uuid.UUID, p ProfileUpdate) (Account, error) {
|
||||
lang := strings.TrimSpace(p.PreferredLanguage)
|
||||
if lang != "en" && lang != "ru" {
|
||||
return Account{}, fmt.Errorf("%w: preferred_language %q", ErrInvalidProfile, p.PreferredLanguage)
|
||||
}
|
||||
tz := strings.TrimSpace(p.TimeZone)
|
||||
if _, err := time.LoadLocation(tz); err != nil {
|
||||
return Account{}, fmt.Errorf("%w: time_zone %q: %v", ErrInvalidProfile, p.TimeZone, err)
|
||||
}
|
||||
name := strings.TrimSpace(p.DisplayName)
|
||||
if utf8.RuneCountInString(name) > maxDisplayName {
|
||||
return Account{}, fmt.Errorf("%w: display name exceeds %d characters", ErrInvalidProfile, maxDisplayName)
|
||||
}
|
||||
|
||||
stmt := table.Accounts.UPDATE(
|
||||
table.Accounts.DisplayName, table.Accounts.PreferredLanguage, table.Accounts.TimeZone,
|
||||
table.Accounts.AwayStart, table.Accounts.AwayEnd,
|
||||
table.Accounts.BlockChat, table.Accounts.BlockFriendRequests, table.Accounts.UpdatedAt,
|
||||
).SET(
|
||||
postgres.String(name), postgres.String(lang), postgres.String(tz),
|
||||
postgres.TimeT(p.AwayStart), postgres.TimeT(p.AwayEnd),
|
||||
postgres.Bool(p.BlockChat), postgres.Bool(p.BlockFriendRequests), postgres.TimestampzT(time.Now().UTC()),
|
||||
).WHERE(table.Accounts.AccountID.EQ(postgres.UUID(id))).
|
||||
RETURNING(table.Accounts.AllColumns)
|
||||
|
||||
var row model.Accounts
|
||||
if err := stmt.QueryContext(ctx, s.db, &row); err != nil {
|
||||
if errors.Is(err, qrm.ErrNoRows) {
|
||||
return Account{}, ErrNotFound
|
||||
}
|
||||
return Account{}, fmt.Errorf("account: update profile %s: %w", id, err)
|
||||
}
|
||||
return modelToAccount(row), nil
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestUpdateProfileValidation checks that bad fields are rejected before any
|
||||
// database access, so a nil-backed Store is enough to exercise the guards.
|
||||
func TestUpdateProfileValidation(t *testing.T) {
|
||||
s := &Store{}
|
||||
base := ProfileUpdate{DisplayName: "Kaya", PreferredLanguage: "en", TimeZone: "UTC"}
|
||||
tests := []struct {
|
||||
name string
|
||||
mut func(p *ProfileUpdate)
|
||||
}{
|
||||
{"unknown language", func(p *ProfileUpdate) { p.PreferredLanguage = "fr" }},
|
||||
{"invalid timezone", func(p *ProfileUpdate) { p.TimeZone = "Mars/Olympus" }},
|
||||
{"over-long name", func(p *ProfileUpdate) { p.DisplayName = strings.Repeat("x", maxDisplayName+1) }},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := base
|
||||
tc.mut(&p)
|
||||
if _, err := s.UpdateProfile(context.Background(), uuid.New(), p); !errors.Is(err, ErrInvalidProfile) {
|
||||
t.Fatalf("err = %v, want ErrInvalidProfile", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user