diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..efa69d2 --- /dev/null +++ b/.env.example @@ -0,0 +1,6 @@ +DATABASE_URL=postgres://tracker:tracker@localhost:5432/tracker?sslmode=disable +JWT_SECRET=dev-secret-change-me +FCM_CREDENTIALS_FILE= +LOCATION_RETENTION_DAYS=30 +WS_PING_INTERVAL=30s +PORT=8080 diff --git a/.gitignore b/.gitignore index 276e69c..5fe8d23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .worktrees/ .superpowers/ +.env +*.exe +/server/tracker diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8b670d9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,12 @@ +FROM golang:1.25-alpine AS builder +WORKDIR /build +COPY server/go.mod ./ +COPY server/go.sum* ./ +RUN go mod download +COPY server/ . +RUN CGO_ENABLED=0 go build -o /tracker ./cmd/tracker + +FROM alpine:3.19 +RUN apk add --no-cache ca-certificates +COPY --from=builder /tracker /usr/local/bin/tracker +ENTRYPOINT ["tracker"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6f0c54c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,33 @@ +services: + postgres: + image: postgis/postgis:16-3.4 + environment: + POSTGRES_USER: tracker + POSTGRES_PASSWORD: tracker + POSTGRES_DB: tracker + ports: + - "5434:5432" + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U tracker"] + interval: 5s + timeout: 3s + retries: 5 + + tracker-server: + build: + context: . + dockerfile: Dockerfile + ports: + - "8080:8080" + environment: + DATABASE_URL: postgres://tracker:tracker@postgres:5432/tracker?sslmode=disable + JWT_SECRET: dev-secret-change-me + PORT: "8080" + depends_on: + postgres: + condition: service_healthy + +volumes: + pgdata: diff --git a/server/cmd/tracker/main.go b/server/cmd/tracker/main.go new file mode 100644 index 0000000..0b72e66 --- /dev/null +++ b/server/cmd/tracker/main.go @@ -0,0 +1,118 @@ +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "strconv" + "syscall" + "time" + + "github.com/nschatz/tracker/server/internal/api" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/geo" + "github.com/nschatz/tracker/server/internal/notify" + "github.com/nschatz/tracker/server/internal/store" + "github.com/nschatz/tracker/server/internal/ws" +) + +func main() { + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + port := envOrDefault("PORT", "8080") + dbURL := requireEnv("DATABASE_URL") + jwtSecret := requireEnv("JWT_SECRET") + fcmCreds := os.Getenv("FCM_CREDENTIALS_FILE") + retentionDays := envIntOrDefault("LOCATION_RETENTION_DAYS", 30) + + db, err := store.New(ctx, dbURL) + if err != nil { + log.Fatalf("database: %v", err) + } + defer db.Close() + + a := auth.New(jwtSecret) + hub := ws.NewHub() + go hub.Run() + + var sender notify.Sender + if fcmCreds != "" { + s, err := notify.NewFCMSender(ctx, fcmCreds) + if err != nil { + log.Fatalf("fcm: %v", err) + } + sender = s + } else { + log.Println("WARNING: FCM_CREDENTIALS_FILE not set, using noop sender") + sender = notify.NoopSender{} + } + notifier := notify.NewNotifier(sender) + geoTracker := geo.NewTracker() + + srv := api.NewServer(a, db, db, db, db, hub, geoTracker, notifier, db, db) + + go runRetention(ctx, db, retentionDays) + + httpSrv := &http.Server{Addr: ":" + port, Handler: srv} + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + httpSrv.Shutdown(shutdownCtx) + }() + + log.Printf("listening on :%s", port) + if err := httpSrv.ListenAndServe(); err != http.ErrServerClosed { + log.Fatalf("http: %v", err) + } +} + +func runRetention(ctx context.Context, db interface { + DeleteLocationsOlderThan(context.Context, int) (int64, error) +}, days int) { + ticker := time.NewTicker(24 * time.Hour) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + count, err := db.DeleteLocationsOlderThan(ctx, days) + if err != nil { + log.Printf("retention: %v", err) + } else if count > 0 { + log.Printf("retention: deleted %d old location rows", count) + } + } + } +} + +func envOrDefault(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +func requireEnv(key string) string { + v := os.Getenv(key) + if v == "" { + log.Fatalf("required env var %s is not set", key) + } + return v +} + +func envIntOrDefault(key string, fallback int) int { + v := os.Getenv(key) + if v == "" { + return fallback + } + n, err := strconv.Atoi(v) + if err != nil { + log.Fatalf("env var %s must be an integer: %v", key, err) + } + return n +} diff --git a/server/go.mod b/server/go.mod new file mode 100644 index 0000000..91e168c --- /dev/null +++ b/server/go.mod @@ -0,0 +1,68 @@ +module github.com/nschatz/tracker/server + +go 1.25.0 + +require ( + cel.dev/expr v0.25.1 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.2 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/firestore v1.21.0 // indirect + cloud.google.com/go/iam v1.5.3 // indirect + cloud.google.com/go/longrunning v0.8.0 // indirect + cloud.google.com/go/monitoring v1.24.3 // indirect + cloud.google.com/go/storage v1.56.0 // indirect + firebase.google.com/go/v4 v4.19.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 // indirect + github.com/MicahParks/keyfunc v1.9.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 // indirect + github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-chi/chi/v5 v5.2.5 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/golang-jwt/jwt/v4 v4.5.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/s2a-go v0.1.9 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.19.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.9.1 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect + github.com/zeebo/errs v1.4.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect + go.opentelemetry.io/otel v1.42.0 // indirect + go.opentelemetry.io/otel/metric v1.42.0 // indirect + go.opentelemetry.io/otel/sdk v1.42.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect + golang.org/x/crypto v0.49.0 // indirect + golang.org/x/net v0.52.0 // indirect + golang.org/x/oauth2 v0.36.0 // indirect + golang.org/x/sync v0.20.0 // indirect + golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/time v0.15.0 // indirect + google.golang.org/api v0.273.1 // indirect + google.golang.org/appengine/v2 v2.0.6 // indirect + google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 // indirect + google.golang.org/grpc v1.79.3 // indirect + google.golang.org/protobuf v1.36.11 // indirect + nhooyr.io/websocket v1.8.17 // indirect +) diff --git a/server/go.sum b/server/go.sum new file mode 100644 index 0000000..f156a2c --- /dev/null +++ b/server/go.sum @@ -0,0 +1,166 @@ +cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM= +cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/firestore v1.21.0 h1:BhopUsx7kh6NFx77ccRsHhrtkbJUmDAxNY3uapWdjcM= +cloud.google.com/go/firestore v1.21.0/go.mod h1:1xH6HNcnkf/gGyR8udd6pFO4Z7GWJSwLKQMx/u6UrP4= +cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= +cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= +cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= +cloud.google.com/go/storage v1.56.0 h1:iixmq2Fse2tqxMbWhLWC9HfBj1qdxqAmiK8/eqtsLxI= +cloud.google.com/go/storage v1.56.0/go.mod h1:Tpuj6t4NweCLzlNbw9Z9iwxEkrSem20AetIeH/shgVU= +firebase.google.com/go/v4 v4.19.0 h1:f5NMlC2YHFsncz00c2+ecBr+ZYlRMhKIhj1z8Iz0lD8= +firebase.google.com/go/v4 v4.19.0/go.mod h1:P7UfBpzc8+Z3MckX79+zsWzKVfpGryr6HLbAe7gCWfs= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0 h1:owcC2UnmsZycprQ5RfRgjydWhuoxg71LUfyiQdijZuM= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.53.0/go.mod h1:ZPpqegjbE99EPKsu3iUWV22A04wzGPcAY/ziSIQEEgs= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0 h1:Ron4zCA/yk6U7WOBXhTJcDpsUBG9npumK6xw2auFltQ= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.53.0/go.mod h1:cSgYe11MCNYunTnRXrKiR/tHc0eoKjICUuWpNZoVCOo= +github.com/MicahParks/keyfunc v1.9.0 h1:lhKd5xrFHLNOWrDc4Tyb/Q1AJ4LCzQ48GVJyVIID3+o= +github.com/MicahParks/keyfunc v1.9.0/go.mod h1:IdnCilugA0O/99dW+/MkvlyrsX8+L8+x95xuVNtM5jw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.19.0 h1:fYQaUOiGwll0cGj7jmHT/0nPlcrZDFPrZRhTsoCr8hE= +github.com/googleapis/gax-go/v2 v2.19.0/go.mod h1:w2ROXVdfGEVFXzmlciUU4EdjHgWvB5h2n6x/8XSTTJA= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zeebo/errs v1.4.0 h1:XNdoD/RRMKP7HD0UhJnIzUy74ISdGGxURlYG8HSWSfM= +github.com/zeebo/errs v1.4.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0 h1:kWRNZMsfBHZ+uHjiH4y7Etn2FK26LAGkNFw7RHv1DhE= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.273.1 h1:L7G/TmpAMz0nKx/ciAVssVmWQiOF6+pOuXeKrWVsquY= +google.golang.org/api v0.273.1/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew= +google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= +google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= +google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 h1:JNfk58HZ8lfmXbYK2vx/UvsqIL59TzByCxPIX4TDmsE= +google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:x5julN69+ED4PcFk/XWayw35O0lf/nGa4aNgODCmNmw= +google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5 h1:CogIeEXn4qWYzzQU0QqvYBM8yDF9cFYzDq9ojSpv0Js= +google.golang.org/genproto/googleapis/api v0.0.0-20260316180232-0b37fe3546d5/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= +nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/server/internal/api/auth_handlers.go b/server/internal/api/auth_handlers.go new file mode 100644 index 0000000..170e059 --- /dev/null +++ b/server/internal/api/auth_handlers.go @@ -0,0 +1,121 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/nschatz/tracker/server/internal/auth" +) + +type authResponse struct { + Token string `json:"token"` + User struct { + ID string `json:"id"` + Email string `json:"email"` + DisplayName string `json:"display_name"` + } `json:"user"` +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, status int, msg string) { + writeJSON(w, status, map[string]string{"error": msg}) +} + +func (s *Server) handleRegister(w http.ResponseWriter, r *http.Request) { + var req struct { + Email string `json:"email"` + DisplayName string `json:"display_name"` + Password string `json:"password"` + InviteCode string `json:"invite_code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Email == "" || req.DisplayName == "" || req.Password == "" || req.InviteCode == "" { + writeError(w, http.StatusBadRequest, "email, display_name, password, and invite_code are required") + return + } + + circle, err := s.circles.GetCircleByInviteCode(r.Context(), req.InviteCode) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid invite code") + return + } + + hash, err := auth.HashPassword(req.Password) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to hash password") + return + } + + user, err := s.store.CreateUser(r.Context(), req.Email, req.DisplayName, hash) + if err != nil { + writeError(w, http.StatusConflict, "could not create user") + return + } + + if err := s.circles.AddMember(r.Context(), circle.ID, user.ID, "member"); err != nil { + writeError(w, http.StatusInternalServerError, "could not add user to circle") + return + } + + token, err := s.auth.IssueToken(user.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not issue token") + return + } + + var resp authResponse + resp.Token = token + resp.User.ID = user.ID.String() + resp.User.Email = user.Email + resp.User.DisplayName = user.DisplayName + + writeJSON(w, http.StatusCreated, resp) +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + var req struct { + Email string `json:"email"` + Password string `json:"password"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Email == "" || req.Password == "" { + writeError(w, http.StatusBadRequest, "email and password are required") + return + } + + user, err := s.store.GetUserByEmail(r.Context(), req.Email) + if err != nil { + writeError(w, http.StatusUnauthorized, "invalid credentials") + return + } + + if !auth.CheckPassword(user.PasswordHash, req.Password) { + writeError(w, http.StatusUnauthorized, "invalid credentials") + return + } + + token, err := s.auth.IssueToken(user.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not issue token") + return + } + + var resp authResponse + resp.Token = token + resp.User.ID = user.ID.String() + resp.User.Email = user.Email + resp.User.DisplayName = user.DisplayName + + writeJSON(w, http.StatusOK, resp) +} diff --git a/server/internal/api/auth_handlers_test.go b/server/internal/api/auth_handlers_test.go new file mode 100644 index 0000000..7d07e6e --- /dev/null +++ b/server/internal/api/auth_handlers_test.go @@ -0,0 +1,170 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/model" +) + +type mockStore struct { + users map[string]*model.User +} + +func newMockStore() *mockStore { + return &mockStore{ + users: make(map[string]*model.User), + } +} + +func (m *mockStore) CreateUser(_ context.Context, email, displayName, passwordHash string) (*model.User, error) { + u := &model.User{ + ID: uuid.New(), + Email: email, + DisplayName: displayName, + PasswordHash: passwordHash, + CreatedAt: time.Now(), + } + m.users[email] = u + return u, nil +} + +func (m *mockStore) GetUserByEmail(_ context.Context, email string) (*model.User, error) { + u, ok := m.users[email] + if !ok { + return nil, ¬FoundError{email} + } + return u, nil +} + +func (m *mockStore) GetUserByID(_ context.Context, id uuid.UUID) (*model.User, error) { + for _, u := range m.users { + if u.ID == id { + return u, nil + } + } + return nil, ¬FoundError{id.String()} +} + +type notFoundError struct{ key string } + +func (e *notFoundError) Error() string { return "not found: " + e.key } + +func TestRegisterAndLogin(t *testing.T) { + store := newMockStore() + + // Pre-create a circle with invite code "abc123" + circleID := uuid.New() + circleStore := newMockCircleStore() + circleStore.circles["abc123"] = &model.Circle{ + ID: circleID, + Name: "Test Circle", + InviteCode: "abc123", + CreatedBy: uuid.New(), + CreatedAt: time.Now(), + } + circleStore.byID[circleID] = circleStore.circles["abc123"] + + a := auth.New("test-secret") + srv := NewServer(a, store, circleStore, nil, nil, nil, nil, nil, nil, nil) + + // Register + regBody, _ := json.Marshal(map[string]string{ + "email": "alice@example.com", + "display_name": "Alice", + "password": "hunter2", + "invite_code": "abc123", + }) + req := httptest.NewRequest(http.MethodPost, "/auth/register", bytes.NewReader(regBody)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("register: want 201, got %d — body: %s", rr.Code, rr.Body.String()) + } + + var regResp authResponse + if err := json.NewDecoder(rr.Body).Decode(®Resp); err != nil { + t.Fatalf("register: decode response: %v", err) + } + if regResp.Token == "" { + t.Fatal("register: expected non-empty token") + } + if regResp.User.ID == "" { + t.Fatal("register: expected non-empty user id") + } + if regResp.User.Email != "alice@example.com" { + t.Fatalf("register: expected email alice@example.com, got %s", regResp.User.Email) + } + if regResp.User.DisplayName != "Alice" { + t.Fatalf("register: expected display_name Alice, got %s", regResp.User.DisplayName) + } + + // Verify member was added to circle (now tracked in circleStore) + if len(circleStore.members) != 1 { + t.Fatalf("register: expected 1 circle member, got %d", len(circleStore.members)) + } + if circleStore.members[0].CircleID != circleID { + t.Fatal("register: member added to wrong circle") + } + + // Login + loginBody, _ := json.Marshal(map[string]string{ + "email": "alice@example.com", + "password": "hunter2", + }) + req2 := httptest.NewRequest(http.MethodPost, "/auth/login", bytes.NewReader(loginBody)) + req2.Header.Set("Content-Type", "application/json") + rr2 := httptest.NewRecorder() + srv.ServeHTTP(rr2, req2) + + if rr2.Code != http.StatusOK { + t.Fatalf("login: want 200, got %d — body: %s", rr2.Code, rr2.Body.String()) + } + + var loginResp authResponse + if err := json.NewDecoder(rr2.Body).Decode(&loginResp); err != nil { + t.Fatalf("login: decode response: %v", err) + } + if loginResp.Token == "" { + t.Fatal("login: expected non-empty token") + } + + // Wrong password should fail + badBody, _ := json.Marshal(map[string]string{ + "email": "alice@example.com", + "password": "wrongpassword", + }) + req3 := httptest.NewRequest(http.MethodPost, "/auth/login", bytes.NewReader(badBody)) + req3.Header.Set("Content-Type", "application/json") + rr3 := httptest.NewRecorder() + srv.ServeHTTP(rr3, req3) + + if rr3.Code != http.StatusUnauthorized { + t.Fatalf("bad login: want 401, got %d", rr3.Code) + } + + // Bad invite code should fail register + badInvite, _ := json.Marshal(map[string]string{ + "email": "bob@example.com", + "display_name": "Bob", + "password": "password", + "invite_code": "invalid", + }) + req4 := httptest.NewRequest(http.MethodPost, "/auth/register", bytes.NewReader(badInvite)) + req4.Header.Set("Content-Type", "application/json") + rr4 := httptest.NewRecorder() + srv.ServeHTTP(rr4, req4) + + if rr4.Code != http.StatusBadRequest { + t.Fatalf("bad invite: want 400, got %d", rr4.Code) + } +} diff --git a/server/internal/api/circle_handlers.go b/server/internal/api/circle_handlers.go new file mode 100644 index 0000000..b47270a --- /dev/null +++ b/server/internal/api/circle_handlers.go @@ -0,0 +1,97 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" +) + +func (s *Server) handleCreateCircle(w http.ResponseWriter, r *http.Request) { + var req struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Name == "" { + writeError(w, http.StatusBadRequest, "name is required") + return + } + + userID := auth.UserIDFromContext(r.Context()) + circle, err := s.circles.CreateCircle(r.Context(), req.Name, userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not create circle") + return + } + + writeJSON(w, http.StatusCreated, circle) +} + +func (s *Server) handleJoinCircle(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + circleID, err := uuid.Parse(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid circle id") + return + } + + var req struct { + InviteCode string `json:"invite_code"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + circle, err := s.circles.GetCircleByInviteCode(r.Context(), req.InviteCode) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid invite code") + return + } + + if circle.ID != circleID { + writeError(w, http.StatusBadRequest, "invite code does not match circle") + return + } + + userID := auth.UserIDFromContext(r.Context()) + if err := s.circles.AddMember(r.Context(), circleID, userID, "member"); err != nil { + writeError(w, http.StatusInternalServerError, "could not join circle") + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "joined"}) +} + +func (s *Server) handleGetMembers(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + circleID, err := uuid.Parse(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid circle id") + return + } + + members, err := s.circles.GetMembers(r.Context(), circleID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not get members") + return + } + + writeJSON(w, http.StatusOK, members) +} + +func (s *Server) handleGetUserCircles(w http.ResponseWriter, r *http.Request) { + userID := auth.UserIDFromContext(r.Context()) + circles, err := s.circles.GetUserCircles(r.Context(), userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not get circles") + return + } + + writeJSON(w, http.StatusOK, circles) +} diff --git a/server/internal/api/circle_handlers_test.go b/server/internal/api/circle_handlers_test.go new file mode 100644 index 0000000..5569fa6 --- /dev/null +++ b/server/internal/api/circle_handlers_test.go @@ -0,0 +1,116 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/model" +) + +type mockCircleStore struct { + circles map[string]*model.Circle // keyed by invite code + byID map[uuid.UUID]*model.Circle + members []model.CircleMember +} + +func newMockCircleStore() *mockCircleStore { + return &mockCircleStore{ + circles: make(map[string]*model.Circle), + byID: make(map[uuid.UUID]*model.Circle), + } +} + +func (m *mockCircleStore) CreateCircle(_ context.Context, name string, createdBy uuid.UUID) (*model.Circle, error) { + c := &model.Circle{ + ID: uuid.New(), + Name: name, + InviteCode: "testcode", + CreatedBy: createdBy, + CreatedAt: time.Now(), + } + m.byID[c.ID] = c + m.circles[c.InviteCode] = c + return c, nil +} + +func (m *mockCircleStore) GetUserCircles(_ context.Context, userID uuid.UUID) ([]model.Circle, error) { + var result []model.Circle + for _, c := range m.byID { + if c.CreatedBy == userID { + result = append(result, *c) + } + } + return result, nil +} + +func (m *mockCircleStore) GetMembers(_ context.Context, circleID uuid.UUID) ([]model.CircleMember, error) { + var result []model.CircleMember + for _, mem := range m.members { + if mem.CircleID == circleID { + result = append(result, mem) + } + } + return result, nil +} + +func (m *mockCircleStore) GetCircleByInviteCode(_ context.Context, code string) (*model.Circle, error) { + c, ok := m.circles[code] + if !ok { + return nil, ¬FoundError{code} + } + return c, nil +} + +func (m *mockCircleStore) AddMember(_ context.Context, circleID, userID uuid.UUID, role string) error { + m.members = append(m.members, model.CircleMember{ + CircleID: circleID, + UserID: userID, + Role: role, + JoinedAt: time.Now(), + }) + return nil +} + +func TestCreateCircle(t *testing.T) { + authSvc := auth.New("test-secret") + authStore := newMockStore() + circleStore := newMockCircleStore() + + srv := NewServer(authSvc, authStore, circleStore, nil, nil, nil, nil, nil, nil, nil) + + userID := uuid.New() + token, err := authSvc.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + body, _ := json.Marshal(map[string]string{"name": "My Circle"}) + req := httptest.NewRequest(http.MethodPost, "/circles", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("want 201, got %d — body: %s", rr.Code, rr.Body.String()) + } + + var circle model.Circle + if err := json.NewDecoder(rr.Body).Decode(&circle); err != nil { + t.Fatalf("decode response: %v", err) + } + if circle.Name != "My Circle" { + t.Errorf("expected name 'My Circle', got %q", circle.Name) + } + if circle.ID == uuid.Nil { + t.Error("expected non-nil circle ID") + } +} diff --git a/server/internal/api/fcm_handlers.go b/server/internal/api/fcm_handlers.go new file mode 100644 index 0000000..5df8087 --- /dev/null +++ b/server/internal/api/fcm_handlers.go @@ -0,0 +1,31 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/nschatz/tracker/server/internal/auth" +) + +func (s *Server) handleRegisterFCMToken(w http.ResponseWriter, r *http.Request) { + userID := auth.UserIDFromContext(r.Context()) + + var req struct { + Token string `json:"token"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Token == "" { + writeError(w, http.StatusBadRequest, "token is required") + return + } + + if err := s.fcmTokens.UpsertFCMToken(r.Context(), userID, req.Token); err != nil { + writeError(w, http.StatusInternalServerError, "could not register FCM token") + return + } + + w.WriteHeader(http.StatusOK) +} diff --git a/server/internal/api/geofence_handlers.go b/server/internal/api/geofence_handlers.go new file mode 100644 index 0000000..7369691 --- /dev/null +++ b/server/internal/api/geofence_handlers.go @@ -0,0 +1,106 @@ +package api + +import ( + "encoding/json" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" +) + +func (s *Server) handleCreateGeofence(w http.ResponseWriter, r *http.Request) { + var req struct { + CircleID uuid.UUID `json:"circle_id"` + Name string `json:"name"` + Lat float64 `json:"lat"` + Lng float64 `json:"lng"` + RadiusMeters float32 `json:"radius_meters"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if req.Name == "" { + writeError(w, http.StatusBadRequest, "name is required") + return + } + if req.RadiusMeters <= 0 { + writeError(w, http.StatusBadRequest, "radius_meters must be greater than 0") + return + } + + userID := auth.UserIDFromContext(r.Context()) + gf, err := s.geofences.CreateGeofence(r.Context(), req.CircleID, req.Name, req.Lat, req.Lng, req.RadiusMeters, userID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not create geofence") + return + } + + writeJSON(w, http.StatusCreated, gf) +} + +func (s *Server) handleGetGeofences(w http.ResponseWriter, r *http.Request) { + circleIDStr := r.URL.Query().Get("circle_id") + if circleIDStr == "" { + writeError(w, http.StatusBadRequest, "circle_id query parameter is required") + return + } + circleID, err := uuid.Parse(circleIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid circle_id") + return + } + + geofences, err := s.geofences.GetGeofences(r.Context(), circleID) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not get geofences") + return + } + + writeJSON(w, http.StatusOK, geofences) +} + +func (s *Server) handleUpdateGeofence(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := uuid.Parse(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid geofence id") + return + } + + var req struct { + Name string `json:"name"` + Lat float64 `json:"lat"` + Lng float64 `json:"lng"` + RadiusMeters float32 `json:"radius_meters"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + + gf, err := s.geofences.UpdateGeofence(r.Context(), id, req.Name, req.Lat, req.Lng, req.RadiusMeters) + if err != nil { + writeError(w, http.StatusInternalServerError, "could not update geofence") + return + } + + writeJSON(w, http.StatusOK, gf) +} + +func (s *Server) handleDeleteGeofence(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := uuid.Parse(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid geofence id") + return + } + + if err := s.geofences.DeleteGeofence(r.Context(), id); err != nil { + writeError(w, http.StatusInternalServerError, "could not delete geofence") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/server/internal/api/location_handlers.go b/server/internal/api/location_handlers.go new file mode 100644 index 0000000..2aa86d1 --- /dev/null +++ b/server/internal/api/location_handlers.go @@ -0,0 +1,202 @@ +package api + +import ( + "context" + "encoding/json" + "log" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/model" +) + +// handlePostLocations handles POST /locations +// Accepts {"locations": [...]} and returns 202 on success. +func (s *Server) handlePostLocations(w http.ResponseWriter, r *http.Request) { + userID := auth.UserIDFromContext(r.Context()) + if userID == uuid.Nil { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + var req struct { + Locations []model.LocationInput `json:"locations"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid JSON") + return + } + if len(req.Locations) == 0 { + writeError(w, http.StatusBadRequest, "locations must not be empty") + return + } + + if err := s.locations.InsertLocations(r.Context(), userID, req.Locations); err != nil { + writeError(w, http.StatusInternalServerError, "failed to store locations") + return + } + + // Get the latest point from the batch for broadcast and geofence eval + latest := req.Locations[len(req.Locations)-1] + loc := model.Location{ + UserID: userID, + Lat: latest.Lat, + Lng: latest.Lng, + Speed: latest.Speed, + BatteryLevel: latest.BatteryLevel, + Accuracy: latest.Accuracy, + RecordedAt: latest.RecordedAt, + } + + // Run broadcast and geofence evaluation in a goroutine so it doesn't block the response + go s.processLocationUpdate(userID, loc) + + w.WriteHeader(http.StatusAccepted) +} + +// processLocationUpdate broadcasts the location and evaluates geofences. +// Uses a detached context since the request context will be cancelled. +func (s *Server) processLocationUpdate(userID uuid.UUID, loc model.Location) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if s.circles == nil { + return + } + + // Get user's circles + circles, err := s.circles.GetUserCircles(ctx, userID) + if err != nil { + log.Printf("processLocationUpdate: get circles for user %s: %v", userID, err) + return + } + + for _, circle := range circles { + // Broadcast location via WebSocket + if s.hub != nil { + s.hub.BroadcastLocation(circle.ID, loc) + } + + // Evaluate geofences + if s.geoEval == nil || s.geoTracker == nil || s.notifier == nil { + continue + } + + containingIDs, err := s.geoEval.FindContainingGeofences(ctx, circle.ID, loc.Lat, loc.Lng) + if err != nil { + log.Printf("processLocationUpdate: find geofences for circle %s: %v", circle.ID, err) + continue + } + + entered, left := s.geoTracker.Update(userID, containingIDs) + if len(entered) == 0 && len(left) == 0 { + continue + } + + // Get user display name + user, err := s.store.GetUserByID(ctx, userID) + if err != nil { + log.Printf("processLocationUpdate: get user %s: %v", userID, err) + continue + } + + // Get geofences for name lookup + geofences, err := s.geoEval.GetGeofences(ctx, circle.ID) + if err != nil { + log.Printf("processLocationUpdate: get geofences for circle %s: %v", circle.ID, err) + continue + } + geoMap := make(map[uuid.UUID]string, len(geofences)) + for _, g := range geofences { + geoMap[g.ID] = g.Name + } + + // Get FCM tokens for other members + if s.fcmTokens == nil { + continue + } + tokens, err := s.fcmTokens.GetFCMTokensForCircle(ctx, circle.ID, userID) + if err != nil { + log.Printf("processLocationUpdate: get FCM tokens for circle %s: %v", circle.ID, err) + continue + } + if len(tokens) == 0 { + continue + } + + for _, geoID := range entered { + name := geoMap[geoID] + s.notifier.GeofenceEnter(ctx, user.DisplayName, name, tokens) + } + for _, geoID := range left { + name := geoMap[geoID] + s.notifier.GeofenceLeave(ctx, user.DisplayName, name, tokens) + } + } +} + +// handleGetLatestLocations handles GET /locations/latest?circle_id=UUID +// Returns the most recent location for each circle member. +func (s *Server) handleGetLatestLocations(w http.ResponseWriter, r *http.Request) { + circleIDStr := r.URL.Query().Get("circle_id") + if circleIDStr == "" { + writeError(w, http.StatusBadRequest, "circle_id is required") + return + } + circleID, err := uuid.Parse(circleIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid circle_id") + return + } + + locs, err := s.locations.GetLatestLocations(r.Context(), circleID) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to get locations") + return + } + + writeJSON(w, http.StatusOK, locs) +} + +// handleGetHistory handles GET /locations/history?user_id=UUID&from=RFC3339&to=RFC3339 +// Returns location history for a user within the given time range. +func (s *Server) handleGetHistory(w http.ResponseWriter, r *http.Request) { + userIDStr := r.URL.Query().Get("user_id") + if userIDStr == "" { + writeError(w, http.StatusBadRequest, "user_id is required") + return + } + userID, err := uuid.Parse(userIDStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid user_id") + return + } + + fromStr := r.URL.Query().Get("from") + toStr := r.URL.Query().Get("to") + if fromStr == "" || toStr == "" { + writeError(w, http.StatusBadRequest, "from and to are required") + return + } + + from, err := time.Parse(time.RFC3339, fromStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid from time: must be RFC3339") + return + } + to, err := time.Parse(time.RFC3339, toStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid to time: must be RFC3339") + return + } + + locs, err := s.locations.GetHistory(r.Context(), userID, from, to) + if err != nil { + writeError(w, http.StatusInternalServerError, "failed to get history") + return + } + + writeJSON(w, http.StatusOK, locs) +} diff --git a/server/internal/api/location_handlers_test.go b/server/internal/api/location_handlers_test.go new file mode 100644 index 0000000..c59104a --- /dev/null +++ b/server/internal/api/location_handlers_test.go @@ -0,0 +1,158 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/model" +) + +type mockLocationStore struct { + inserted []model.LocationInput + forUser uuid.UUID +} + +func (m *mockLocationStore) InsertLocations(_ context.Context, userID uuid.UUID, locs []model.LocationInput) error { + m.forUser = userID + m.inserted = append(m.inserted, locs...) + return nil +} + +func (m *mockLocationStore) GetLatestLocations(_ context.Context, circleID uuid.UUID) ([]model.Location, error) { + return []model.Location{}, nil +} + +func (m *mockLocationStore) GetHistory(_ context.Context, userID uuid.UUID, from, to time.Time) ([]model.Location, error) { + return []model.Location{}, nil +} + +func TestPostLocations(t *testing.T) { + authSvc := auth.New("test-secret") + authStore := newMockStore() + locStore := &mockLocationStore{} + + srv := NewServer(authSvc, authStore, nil, locStore, nil, nil, nil, nil, nil, nil) + + // Create a user and issue a token + userID := uuid.New() + token, err := authSvc.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + now := time.Now().UTC().Truncate(time.Second) + speed := float32(3.5) + locs := []model.LocationInput{ + {Lat: 40.7128, Lng: -74.0060, Speed: &speed, RecordedAt: now.Add(-time.Minute)}, + {Lat: 40.7130, Lng: -74.0058, RecordedAt: now}, + } + + body, _ := json.Marshal(map[string]any{ + "locations": locs, + }) + + req := httptest.NewRequest(http.MethodPost, "/locations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("want 202, got %d — body: %s", rr.Code, rr.Body.String()) + } + + // Verify locations were stored + if len(locStore.inserted) != 2 { + t.Fatalf("expected 2 inserted locations, got %d", len(locStore.inserted)) + } + if locStore.forUser != userID { + t.Errorf("expected userID %v, got %v", userID, locStore.forUser) + } + if locStore.inserted[0].Lat != 40.7128 { + t.Errorf("expected lat 40.7128, got %v", locStore.inserted[0].Lat) + } +} + +func TestPostLocations_Unauthenticated(t *testing.T) { + authSvc := auth.New("test-secret") + authStore := newMockStore() + locStore := &mockLocationStore{} + + srv := NewServer(authSvc, authStore, nil, locStore, nil, nil, nil, nil, nil, nil) + + body, _ := json.Marshal(map[string]any{ + "locations": []model.LocationInput{{Lat: 1.0, Lng: 1.0, RecordedAt: time.Now()}}, + }) + + req := httptest.NewRequest(http.MethodPost, "/locations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + // No Authorization header + + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("want 401, got %d", rr.Code) + } +} + +func TestGetLatestLocations(t *testing.T) { + authSvc := auth.New("test-secret") + authStore := newMockStore() + locStore := &mockLocationStore{} + + srv := NewServer(authSvc, authStore, nil, locStore, nil, nil, nil, nil, nil, nil) + + userID := uuid.New() + token, err := authSvc.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + circleID := uuid.New() + req := httptest.NewRequest(http.MethodGet, "/locations/latest?circle_id="+circleID.String(), nil) + req.Header.Set("Authorization", "Bearer "+token) + + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("want 200, got %d — body: %s", rr.Code, rr.Body.String()) + } +} + +func TestGetHistory(t *testing.T) { + authSvc := auth.New("test-secret") + authStore := newMockStore() + locStore := &mockLocationStore{} + + srv := NewServer(authSvc, authStore, nil, locStore, nil, nil, nil, nil, nil, nil) + + userID := uuid.New() + token, err := authSvc.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken: %v", err) + } + + from := time.Now().Add(-time.Hour).UTC().Format(time.RFC3339) + to := time.Now().UTC().Format(time.RFC3339) + url := "/locations/history?user_id=" + userID.String() + "&from=" + from + "&to=" + to + + req := httptest.NewRequest(http.MethodGet, url, nil) + req.Header.Set("Authorization", "Bearer "+token) + + rr := httptest.NewRecorder() + srv.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("want 200, got %d — body: %s", rr.Code, rr.Body.String()) + } +} diff --git a/server/internal/api/server.go b/server/internal/api/server.go new file mode 100644 index 0000000..5fdab6a --- /dev/null +++ b/server/internal/api/server.go @@ -0,0 +1,126 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/auth" + "github.com/nschatz/tracker/server/internal/geo" + "github.com/nschatz/tracker/server/internal/model" + "github.com/nschatz/tracker/server/internal/notify" + "github.com/nschatz/tracker/server/internal/ws" +) + +type AuthStore interface { + CreateUser(ctx context.Context, email, displayName, passwordHash string) (*model.User, error) + GetUserByEmail(ctx context.Context, email string) (*model.User, error) + GetUserByID(ctx context.Context, id uuid.UUID) (*model.User, error) +} + +type CircleStore interface { + CreateCircle(ctx context.Context, name string, createdBy uuid.UUID) (*model.Circle, error) + GetUserCircles(ctx context.Context, userID uuid.UUID) ([]model.Circle, error) + GetMembers(ctx context.Context, circleID uuid.UUID) ([]model.CircleMember, error) + GetCircleByInviteCode(ctx context.Context, code string) (*model.Circle, error) + AddMember(ctx context.Context, circleID, userID uuid.UUID, role string) error +} + +type LocationStore interface { + InsertLocations(ctx context.Context, userID uuid.UUID, locs []model.LocationInput) error + GetLatestLocations(ctx context.Context, circleID uuid.UUID) ([]model.Location, error) + GetHistory(ctx context.Context, userID uuid.UUID, from, to time.Time) ([]model.Location, error) +} + +type GeofenceStore interface { + CreateGeofence(ctx context.Context, circleID uuid.UUID, name string, lat, lng float64, radiusMeters float32, createdBy uuid.UUID) (*model.Geofence, error) + GetGeofences(ctx context.Context, circleID uuid.UUID) ([]model.Geofence, error) + UpdateGeofence(ctx context.Context, id uuid.UUID, name string, lat, lng float64, radiusMeters float32) (*model.Geofence, error) + DeleteGeofence(ctx context.Context, id uuid.UUID) error +} + +type GeoEvaluator interface { + FindContainingGeofences(ctx context.Context, circleID uuid.UUID, lat, lng float64) ([]uuid.UUID, error) + GetGeofences(ctx context.Context, circleID uuid.UUID) ([]model.Geofence, error) +} + +type FCMTokenStore interface { + UpsertFCMToken(ctx context.Context, userID uuid.UUID, token string) error + GetFCMTokensForCircle(ctx context.Context, circleID uuid.UUID, excludeUserID uuid.UUID) ([]string, error) +} + +type Server struct { + router chi.Router + auth *auth.Auth + store AuthStore + circles CircleStore + locations LocationStore + geofences GeofenceStore + hub *ws.Hub + geoTracker *geo.Tracker + notifier *notify.Notifier + geoEval GeoEvaluator + fcmTokens FCMTokenStore +} + +func NewServer(a *auth.Auth, store AuthStore, circles CircleStore, locations LocationStore, geofences GeofenceStore, hub *ws.Hub, geoTracker *geo.Tracker, notifier *notify.Notifier, geoEval GeoEvaluator, fcmTokens FCMTokenStore) *Server { + s := &Server{ + router: chi.NewRouter(), + auth: a, + store: store, + circles: circles, + locations: locations, + geofences: geofences, + hub: hub, + geoTracker: geoTracker, + notifier: notifier, + geoEval: geoEval, + fcmTokens: fcmTokens, + } + + s.router.Use(middleware.Logger) + s.router.Use(middleware.Recoverer) + + s.router.Get("/health", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok\n")) + }) + + s.router.Post("/auth/register", s.handleRegister) + s.router.Post("/auth/login", s.handleLogin) + + s.router.Group(func(r chi.Router) { + r.Use(a.Middleware) + r.Post("/locations", s.handlePostLocations) + r.Get("/locations/latest", s.handleGetLatestLocations) + r.Get("/locations/history", s.handleGetHistory) + + r.Post("/circles", s.handleCreateCircle) + r.Post("/circles/{id}/join", s.handleJoinCircle) + r.Get("/circles/{id}/members", s.handleGetMembers) + r.Get("/circles", s.handleGetUserCircles) + + r.Post("/geofences", s.handleCreateGeofence) + r.Get("/geofences", s.handleGetGeofences) + r.Put("/geofences/{id}", s.handleUpdateGeofence) + r.Delete("/geofences/{id}", s.handleDeleteGeofence) + + r.Get("/ws", s.handleWebSocket) + r.Post("/fcm-token", s.handleRegisterFCMToken) + }) + + return s +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.router.ServeHTTP(w, r) +} + +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + userID := auth.UserIDFromContext(r.Context()) + circleIDStr := r.URL.Query().Get("circle_id") + circleID, _ := uuid.Parse(circleIDStr) + s.hub.HandleConnect(w, r, userID, circleID) +} diff --git a/server/internal/auth/auth.go b/server/internal/auth/auth.go new file mode 100644 index 0000000..a86fedd --- /dev/null +++ b/server/internal/auth/auth.go @@ -0,0 +1,69 @@ +package auth + +import ( + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" +) + +type Auth struct { + secret []byte +} + +func New(secret string) *Auth { + return &Auth{secret: []byte(secret)} +} + +func HashPassword(password string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", fmt.Errorf("hash password: %w", err) + } + return string(hash), nil +} + +func CheckPassword(hash, password string) bool { + return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil +} + +func (a *Auth) IssueToken(userID uuid.UUID) (string, error) { + claims := jwt.MapClaims{ + "sub": userID.String(), + "iat": time.Now().Unix(), + "exp": time.Now().Add(30 * 24 * time.Hour).Unix(), + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString(a.secret) + if err != nil { + return "", fmt.Errorf("sign token: %w", err) + } + return signed, nil +} + +func (a *Auth) ParseToken(tokenStr string) (uuid.UUID, error) { + token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return a.secret, nil + }) + if err != nil { + return uuid.Nil, fmt.Errorf("parse token: %w", err) + } + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return uuid.Nil, fmt.Errorf("invalid token claims") + } + sub, ok := claims["sub"].(string) + if !ok { + return uuid.Nil, fmt.Errorf("missing sub claim") + } + id, err := uuid.Parse(sub) + if err != nil { + return uuid.Nil, fmt.Errorf("parse sub as uuid: %w", err) + } + return id, nil +} diff --git a/server/internal/auth/auth_test.go b/server/internal/auth/auth_test.go new file mode 100644 index 0000000..11e8912 --- /dev/null +++ b/server/internal/auth/auth_test.go @@ -0,0 +1,96 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" +) + +func TestHashAndCheckPassword(t *testing.T) { + password := "supersecret" + hash, err := HashPassword(password) + if err != nil { + t.Fatalf("HashPassword error: %v", err) + } + if !CheckPassword(hash, password) { + t.Error("CheckPassword returned false for correct password") + } + if CheckPassword(hash, "wrongpassword") { + t.Error("CheckPassword returned true for wrong password") + } +} + +func TestIssueAndParseToken(t *testing.T) { + a := New("test-secret") + userID := uuid.New() + + token, err := a.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken error: %v", err) + } + + parsed, err := a.ParseToken(token) + if err != nil { + t.Fatalf("ParseToken error: %v", err) + } + + if parsed != userID { + t.Errorf("parsed UUID %v does not match original %v", parsed, userID) + } +} + +func TestParseTokenInvalid(t *testing.T) { + a := New("test-secret") + _, err := a.ParseToken("garbage") + if err == nil { + t.Error("expected error parsing invalid token, got nil") + } +} + +func TestMiddleware(t *testing.T) { + a := New("test-secret") + userID := uuid.New() + + token, err := a.IssueToken(userID) + if err != nil { + t.Fatalf("IssueToken error: %v", err) + } + + var capturedID uuid.UUID + handler := a.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedID = UserIDFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer "+token) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rec.Code) + } + if capturedID != userID { + t.Errorf("context UUID %v does not match original %v", capturedID, userID) + } +} + +func TestMiddlewareNoToken(t *testing.T) { + a := New("test-secret") + + handler := a.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("expected status 401, got %d", rec.Code) + } +} diff --git a/server/internal/auth/middleware.go b/server/internal/auth/middleware.go new file mode 100644 index 0000000..9993e17 --- /dev/null +++ b/server/internal/auth/middleware.go @@ -0,0 +1,34 @@ +package auth + +import ( + "context" + "net/http" + "strings" + + "github.com/google/uuid" +) + +type contextKey struct{} + +func (a *Auth) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("Authorization") + if !strings.HasPrefix(header, "Bearer ") { + http.Error(w, "missing authorization header", http.StatusUnauthorized) + return + } + tokenStr := strings.TrimPrefix(header, "Bearer ") + userID, err := a.ParseToken(tokenStr) + if err != nil { + http.Error(w, "invalid token", http.StatusUnauthorized) + return + } + ctx := context.WithValue(r.Context(), contextKey{}, userID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func UserIDFromContext(ctx context.Context) uuid.UUID { + id, _ := ctx.Value(contextKey{}).(uuid.UUID) + return id +} diff --git a/server/internal/geo/geo.go b/server/internal/geo/geo.go new file mode 100644 index 0000000..ee28a44 --- /dev/null +++ b/server/internal/geo/geo.go @@ -0,0 +1,65 @@ +package geo + +import ( + "sync" + + "github.com/google/uuid" +) + +// Tracker maintains in-memory state of which geofences each user is currently inside. +type Tracker struct { + mu sync.Mutex + state map[uuid.UUID]map[uuid.UUID]struct{} // user_id -> set of geofence_ids +} + +func NewTracker() *Tracker { + return &Tracker{ + state: make(map[uuid.UUID]map[uuid.UUID]struct{}), + } +} + +// Update takes the set of geofences a user is currently inside and returns +// which geofences were entered and which were left since the last update. +func (t *Tracker) Update(userID uuid.UUID, currentGeofences []uuid.UUID) (entered, left []uuid.UUID) { + t.mu.Lock() + defer t.mu.Unlock() + + previous, ok := t.state[userID] + if !ok { + previous = make(map[uuid.UUID]struct{}) + } + + current := make(map[uuid.UUID]struct{}, len(currentGeofences)) + for _, id := range currentGeofences { + current[id] = struct{}{} + } + + // Find entered: in current but not in previous + for id := range current { + if _, wasThere := previous[id]; !wasThere { + entered = append(entered, id) + } + } + + // Find left: in previous but not in current + for id := range previous { + if _, isThere := current[id]; !isThere { + left = append(left, id) + } + } + + t.state[userID] = current + return entered, left +} + +// SetState rebuilds state for a user (used on server startup). +func (t *Tracker) SetState(userID uuid.UUID, geofenceIDs []uuid.UUID) { + t.mu.Lock() + defer t.mu.Unlock() + + s := make(map[uuid.UUID]struct{}, len(geofenceIDs)) + for _, id := range geofenceIDs { + s[id] = struct{}{} + } + t.state[userID] = s +} diff --git a/server/internal/geo/geo_test.go b/server/internal/geo/geo_test.go new file mode 100644 index 0000000..924f28c --- /dev/null +++ b/server/internal/geo/geo_test.go @@ -0,0 +1,50 @@ +package geo + +import ( + "testing" + + "github.com/google/uuid" +) + +func TestDetectTransitions(t *testing.T) { + tracker := NewTracker() + userID := uuid.New() + gfHome := uuid.New() + gfWork := uuid.New() + + // Step 1: User starts nowhere -> enters gfHome + entered, left := tracker.Update(userID, []uuid.UUID{gfHome}) + if len(entered) != 1 || entered[0] != gfHome { + t.Fatalf("step 1: expected entered=[gfHome], got %v", entered) + } + if len(left) != 0 { + t.Fatalf("step 1: expected left=[], got %v", left) + } + + // Step 2: User moves to gfWork -> entered=[gfWork], left=[gfHome] + entered, left = tracker.Update(userID, []uuid.UUID{gfWork}) + if len(entered) != 1 || entered[0] != gfWork { + t.Fatalf("step 2: expected entered=[gfWork], got %v", entered) + } + if len(left) != 1 || left[0] != gfHome { + t.Fatalf("step 2: expected left=[gfHome], got %v", left) + } + + // Step 3: User stays at gfWork -> entered=[], left=[] + entered, left = tracker.Update(userID, []uuid.UUID{gfWork}) + if len(entered) != 0 { + t.Fatalf("step 3: expected entered=[], got %v", entered) + } + if len(left) != 0 { + t.Fatalf("step 3: expected left=[], got %v", left) + } + + // Step 4: User leaves all -> entered=[], left=[gfWork] + entered, left = tracker.Update(userID, []uuid.UUID{}) + if len(entered) != 0 { + t.Fatalf("step 4: expected entered=[], got %v", entered) + } + if len(left) != 1 || left[0] != gfWork { + t.Fatalf("step 4: expected left=[gfWork], got %v", left) + } +} diff --git a/server/internal/model/model.go b/server/internal/model/model.go new file mode 100644 index 0000000..71c3979 --- /dev/null +++ b/server/internal/model/model.go @@ -0,0 +1,63 @@ +package model + +import ( + "time" + + "github.com/google/uuid" +) + +type User struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + DisplayName string `json:"display_name"` + PasswordHash string `json:"-"` + CreatedAt time.Time `json:"created_at"` +} + +type Circle struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + InviteCode string `json:"invite_code"` + CreatedBy uuid.UUID `json:"created_by"` + CreatedAt time.Time `json:"created_at"` +} + +type CircleMember struct { + CircleID uuid.UUID `json:"circle_id"` + UserID uuid.UUID `json:"user_id"` + Role string `json:"role"` + JoinedAt time.Time `json:"joined_at"` + DisplayName string `json:"display_name"` + Email string `json:"email"` +} + +type Location struct { + ID int64 `json:"id"` + UserID uuid.UUID `json:"user_id"` + Lat float64 `json:"lat"` + Lng float64 `json:"lng"` + Speed *float32 `json:"speed,omitempty"` + BatteryLevel *int16 `json:"battery_level,omitempty"` + Accuracy *float32 `json:"accuracy,omitempty"` + RecordedAt time.Time `json:"recorded_at"` +} + +type LocationInput struct { + Lat float64 `json:"lat"` + Lng float64 `json:"lng"` + Speed *float32 `json:"speed,omitempty"` + BatteryLevel *int16 `json:"battery_level,omitempty"` + Accuracy *float32 `json:"accuracy,omitempty"` + RecordedAt time.Time `json:"recorded_at"` +} + +type Geofence struct { + ID uuid.UUID `json:"id"` + CircleID uuid.UUID `json:"circle_id"` + Name string `json:"name"` + Lat float64 `json:"lat"` + Lng float64 `json:"lng"` + RadiusMeters float32 `json:"radius_meters"` + CreatedBy uuid.UUID `json:"created_by"` + CreatedAt time.Time `json:"created_at"` +} diff --git a/server/internal/notify/fcm.go b/server/internal/notify/fcm.go new file mode 100644 index 0000000..7a871f5 --- /dev/null +++ b/server/internal/notify/fcm.go @@ -0,0 +1,55 @@ +package notify + +import ( + "context" + "fmt" + + firebase "firebase.google.com/go/v4" + "firebase.google.com/go/v4/messaging" + "google.golang.org/api/option" +) + +// FCMSender uses Firebase Admin SDK. +type FCMSender struct { + client *messaging.Client +} + +func NewFCMSender(ctx context.Context, credentialsFile string) (*FCMSender, error) { + app, err := firebase.NewApp(ctx, nil, option.WithCredentialsFile(credentialsFile)) + if err != nil { + return nil, fmt.Errorf("fcm: firebase.NewApp: %w", err) + } + client, err := app.Messaging(ctx) + if err != nil { + return nil, fmt.Errorf("fcm: app.Messaging: %w", err) + } + return &FCMSender{client: client}, nil +} + +func (f *FCMSender) Send(ctx context.Context, msg Message) error { + _, err := f.client.Send(ctx, &messaging.Message{ + Token: msg.Token, + Notification: &messaging.Notification{ + Title: msg.Title, + Body: msg.Body, + }, + Android: &messaging.AndroidConfig{ + Priority: "high", + Notification: &messaging.AndroidNotification{ + ChannelID: "place_alerts", + }, + }, + }) + if err != nil { + return fmt.Errorf("fcm: send: %w", err) + } + return nil +} + +// NoopSender is used when FCM credentials are not configured. +type NoopSender struct{} + +func (n NoopSender) Send(ctx context.Context, msg Message) error { + fmt.Printf("noop fcm: to=%s title=%q body=%q\n", msg.Token, msg.Title, msg.Body) + return nil +} diff --git a/server/internal/notify/notify.go b/server/internal/notify/notify.go new file mode 100644 index 0000000..3494954 --- /dev/null +++ b/server/internal/notify/notify.go @@ -0,0 +1,53 @@ +package notify + +import ( + "context" + "fmt" + "log" +) + +type Message struct { + Token string + Title string + Body string +} + +type Sender interface { + Send(ctx context.Context, msg Message) error +} + +type Notifier struct { + sender Sender +} + +func NewNotifier(sender Sender) *Notifier { + return &Notifier{sender: sender} +} + +func (n *Notifier) GeofenceEnter(ctx context.Context, userName, placeName string, fcmTokens []string) { + body := fmt.Sprintf("%s arrived at %s", userName, placeName) + for _, token := range fcmTokens { + msg := Message{ + Token: token, + Title: "Location Update", + Body: body, + } + if err := n.sender.Send(ctx, msg); err != nil { + log.Printf("notify: GeofenceEnter send error for token %s: %v", token, err) + } + } +} + +func (n *Notifier) GeofenceLeave(ctx context.Context, userName, placeName string, fcmTokens []string) { + body := fmt.Sprintf("%s left %s", userName, placeName) + for _, token := range fcmTokens { + msg := Message{ + Token: token, + Title: "Location Update", + Body: body, + } + if err := n.sender.Send(ctx, msg); err != nil { + log.Printf("notify: GeofenceLeave send error for token %s: %v", token, err) + } + } +} diff --git a/server/internal/notify/notify_test.go b/server/internal/notify/notify_test.go new file mode 100644 index 0000000..821f39e --- /dev/null +++ b/server/internal/notify/notify_test.go @@ -0,0 +1,59 @@ +package notify + +import ( + "context" + "testing" +) + +type mockSender struct { + messages []Message +} + +func (m *mockSender) Send(ctx context.Context, msg Message) error { + m.messages = append(m.messages, msg) + return nil +} + +func TestGeofenceEnter(t *testing.T) { + sender := &mockSender{} + notifier := NewNotifier(sender) + + ctx := context.Background() + notifier.GeofenceEnter(ctx, "Alice", "Home", []string{"token-1"}) + + if len(sender.messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(sender.messages)) + } + msg := sender.messages[0] + if msg.Token != "token-1" { + t.Errorf("expected token 'token-1', got %q", msg.Token) + } + expectedBody := "Alice arrived at Home" + if msg.Body != expectedBody { + t.Errorf("expected body %q, got %q", expectedBody, msg.Body) + } +} + +func TestGeofenceLeave(t *testing.T) { + sender := &mockSender{} + notifier := NewNotifier(sender) + + ctx := context.Background() + notifier.GeofenceLeave(ctx, "Alice", "Work", []string{"token-1", "token-2"}) + + if len(sender.messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(sender.messages)) + } + expectedBody := "Alice left Work" + for i, msg := range sender.messages { + if msg.Body != expectedBody { + t.Errorf("message %d: expected body %q, got %q", i, expectedBody, msg.Body) + } + } + if sender.messages[0].Token != "token-1" { + t.Errorf("expected first token 'token-1', got %q", sender.messages[0].Token) + } + if sender.messages[1].Token != "token-2" { + t.Errorf("expected second token 'token-2', got %q", sender.messages[1].Token) + } +} diff --git a/server/internal/store/circles.go b/server/internal/store/circles.go new file mode 100644 index 0000000..67b04dc --- /dev/null +++ b/server/internal/store/circles.go @@ -0,0 +1,138 @@ +package store + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" +) + +func generateInviteCode() (string, error) { + b := make([]byte, 6) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate invite code: %w", err) + } + return hex.EncodeToString(b), nil +} + +func (s *Store) CreateCircle(ctx context.Context, name string, createdBy uuid.UUID) (*model.Circle, error) { + code, err := generateInviteCode() + if err != nil { + return nil, err + } + + tx, err := s.pool.Begin(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer tx.Rollback(ctx) + + var c model.Circle + err = tx.QueryRow(ctx, + `INSERT INTO circles (name, invite_code, created_by) + VALUES ($1, $2, $3) + RETURNING id, name, invite_code, created_by, created_at`, + name, code, createdBy, + ).Scan(&c.ID, &c.Name, &c.InviteCode, &c.CreatedBy, &c.CreatedAt) + if err != nil { + return nil, fmt.Errorf("insert circle: %w", err) + } + + _, err = tx.Exec(ctx, + `INSERT INTO circle_members (circle_id, user_id, role) + VALUES ($1, $2, 'admin')`, + c.ID, createdBy, + ) + if err != nil { + return nil, fmt.Errorf("add creator as admin: %w", err) + } + + if err := tx.Commit(ctx); err != nil { + return nil, fmt.Errorf("commit create circle: %w", err) + } + + return &c, nil +} + +func (s *Store) GetCircleByInviteCode(ctx context.Context, code string) (*model.Circle, error) { + var c model.Circle + err := s.pool.QueryRow(ctx, + `SELECT id, name, invite_code, created_by, created_at + FROM circles WHERE invite_code = $1`, + code, + ).Scan(&c.ID, &c.Name, &c.InviteCode, &c.CreatedBy, &c.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get circle by invite code: %w", err) + } + return &c, nil +} + +func (s *Store) AddMember(ctx context.Context, circleID, userID uuid.UUID, role string) error { + _, err := s.pool.Exec(ctx, + `INSERT INTO circle_members (circle_id, user_id, role) + VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING`, + circleID, userID, role, + ) + if err != nil { + return fmt.Errorf("add member: %w", err) + } + return nil +} + +func (s *Store) GetMembers(ctx context.Context, circleID uuid.UUID) ([]model.CircleMember, error) { + rows, err := s.pool.Query(ctx, + `SELECT cm.circle_id, cm.user_id, cm.role, cm.joined_at, u.display_name, u.email + FROM circle_members cm + JOIN users u ON u.id = cm.user_id + WHERE cm.circle_id = $1`, + circleID, + ) + if err != nil { + return nil, fmt.Errorf("get members: %w", err) + } + defer rows.Close() + + var members []model.CircleMember + for rows.Next() { + var m model.CircleMember + if err := rows.Scan(&m.CircleID, &m.UserID, &m.Role, &m.JoinedAt, &m.DisplayName, &m.Email); err != nil { + return nil, fmt.Errorf("scan member: %w", err) + } + members = append(members, m) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate members: %w", err) + } + return members, nil +} + +func (s *Store) GetUserCircles(ctx context.Context, userID uuid.UUID) ([]model.Circle, error) { + rows, err := s.pool.Query(ctx, + `SELECT c.id, c.name, c.invite_code, c.created_by, c.created_at + FROM circles c + JOIN circle_members cm ON cm.circle_id = c.id + WHERE cm.user_id = $1`, + userID, + ) + if err != nil { + return nil, fmt.Errorf("get user circles: %w", err) + } + defer rows.Close() + + var circles []model.Circle + for rows.Next() { + var c model.Circle + if err := rows.Scan(&c.ID, &c.Name, &c.InviteCode, &c.CreatedBy, &c.CreatedAt); err != nil { + return nil, fmt.Errorf("scan circle: %w", err) + } + circles = append(circles, c) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate circles: %w", err) + } + return circles, nil +} diff --git a/server/internal/store/fcm_tokens.go b/server/internal/store/fcm_tokens.go new file mode 100644 index 0000000..d6c487a --- /dev/null +++ b/server/internal/store/fcm_tokens.go @@ -0,0 +1,47 @@ +package store + +import ( + "context" + "fmt" + + "github.com/google/uuid" +) + +func (s *Store) UpsertFCMToken(ctx context.Context, userID uuid.UUID, token string) error { + _, err := s.pool.Exec(ctx, + `INSERT INTO fcm_tokens (user_id, token) + VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET token = $2, updated_at = now()`, + userID, token, + ) + if err != nil { + return fmt.Errorf("upsert fcm token: %w", err) + } + return nil +} + +func (s *Store) GetFCMTokensForCircle(ctx context.Context, circleID uuid.UUID, excludeUserID uuid.UUID) ([]string, error) { + rows, err := s.pool.Query(ctx, + `SELECT ft.token FROM fcm_tokens ft + JOIN circle_members cm ON cm.user_id = ft.user_id + WHERE cm.circle_id = $1 AND ft.user_id != $2`, + circleID, excludeUserID, + ) + if err != nil { + return nil, fmt.Errorf("get fcm tokens for circle: %w", err) + } + defer rows.Close() + + var tokens []string + for rows.Next() { + var token string + if err := rows.Scan(&token); err != nil { + return nil, fmt.Errorf("scan fcm token: %w", err) + } + tokens = append(tokens, token) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate fcm tokens: %w", err) + } + return tokens, nil +} diff --git a/server/internal/store/geofences.go b/server/internal/store/geofences.go new file mode 100644 index 0000000..d3e9214 --- /dev/null +++ b/server/internal/store/geofences.go @@ -0,0 +1,97 @@ +package store + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" +) + +func (s *Store) CreateGeofence(ctx context.Context, circleID uuid.UUID, name string, lat, lng float64, radiusMeters float32, createdBy uuid.UUID) (*model.Geofence, error) { + var g model.Geofence + err := s.pool.QueryRow(ctx, + `INSERT INTO geofences (circle_id, name, center, radius_meters, created_by) + VALUES ($1, $2, ST_SetSRID(ST_MakePoint($3, $4), 4326)::geography, $5, $6) + RETURNING id, circle_id, name, ST_Y(center::geometry), ST_X(center::geometry), radius_meters, created_by, created_at`, + circleID, name, lng, lat, radiusMeters, createdBy, + ).Scan(&g.ID, &g.CircleID, &g.Name, &g.Lat, &g.Lng, &g.RadiusMeters, &g.CreatedBy, &g.CreatedAt) + if err != nil { + return nil, fmt.Errorf("create geofence: %w", err) + } + return &g, nil +} + +func (s *Store) GetGeofences(ctx context.Context, circleID uuid.UUID) ([]model.Geofence, error) { + rows, err := s.pool.Query(ctx, + `SELECT id, circle_id, name, ST_Y(center::geometry), ST_X(center::geometry), radius_meters, created_by, created_at + FROM geofences WHERE circle_id = $1`, + circleID, + ) + if err != nil { + return nil, fmt.Errorf("get geofences: %w", err) + } + defer rows.Close() + + var geofences []model.Geofence + for rows.Next() { + var g model.Geofence + if err := rows.Scan(&g.ID, &g.CircleID, &g.Name, &g.Lat, &g.Lng, &g.RadiusMeters, &g.CreatedBy, &g.CreatedAt); err != nil { + return nil, fmt.Errorf("scan geofence: %w", err) + } + geofences = append(geofences, g) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate geofences: %w", err) + } + return geofences, nil +} + +func (s *Store) UpdateGeofence(ctx context.Context, id uuid.UUID, name string, lat, lng float64, radiusMeters float32) (*model.Geofence, error) { + var g model.Geofence + err := s.pool.QueryRow(ctx, + `UPDATE geofences + SET name = $2, center = ST_SetSRID(ST_MakePoint($3, $4), 4326)::geography, radius_meters = $5 + WHERE id = $1 + RETURNING id, circle_id, name, ST_Y(center::geometry), ST_X(center::geometry), radius_meters, created_by, created_at`, + id, name, lng, lat, radiusMeters, + ).Scan(&g.ID, &g.CircleID, &g.Name, &g.Lat, &g.Lng, &g.RadiusMeters, &g.CreatedBy, &g.CreatedAt) + if err != nil { + return nil, fmt.Errorf("update geofence: %w", err) + } + return &g, nil +} + +func (s *Store) DeleteGeofence(ctx context.Context, id uuid.UUID) error { + _, err := s.pool.Exec(ctx, `DELETE FROM geofences WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete geofence: %w", err) + } + return nil +} + +func (s *Store) FindContainingGeofences(ctx context.Context, circleID uuid.UUID, lat, lng float64) ([]uuid.UUID, error) { + rows, err := s.pool.Query(ctx, + `SELECT id FROM geofences + WHERE circle_id = $1 + AND ST_DWithin(center, ST_SetSRID(ST_MakePoint($2, $3), 4326)::geography, radius_meters)`, + circleID, lng, lat, + ) + if err != nil { + return nil, fmt.Errorf("find containing geofences: %w", err) + } + defer rows.Close() + + var ids []uuid.UUID + for rows.Next() { + var id uuid.UUID + if err := rows.Scan(&id); err != nil { + return nil, fmt.Errorf("scan geofence id: %w", err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate geofence ids: %w", err) + } + return ids, nil +} diff --git a/server/internal/store/locations.go b/server/internal/store/locations.go new file mode 100644 index 0000000..7e7eb7f --- /dev/null +++ b/server/internal/store/locations.go @@ -0,0 +1,127 @@ +package store + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" +) + +// InsertLocations bulk-inserts location points using a single INSERT with multiple VALUE tuples. +func (s *Store) InsertLocations(ctx context.Context, userID uuid.UUID, locs []model.LocationInput) error { + if len(locs) == 0 { + return nil + } + + // Build: INSERT INTO locations (user_id, point, speed, battery_level, accuracy, recorded_at) VALUES ... + // ST_MakePoint takes (lng, lat) and we cast to geography. + // Each row occupies 7 params: user_id, lng, lat, speed, battery_level, accuracy, recorded_at + valueStrings := make([]string, 0, len(locs)) + args := make([]any, 0, len(locs)*7) + + for i, loc := range locs { + base := i * 7 + valueStrings = append(valueStrings, fmt.Sprintf( + "($%d, ST_SetSRID(ST_MakePoint($%d, $%d), 4326)::geography, $%d, $%d, $%d, $%d)", + base+1, base+2, base+3, base+4, base+5, base+6, base+7, + )) + args = append(args, userID, loc.Lng, loc.Lat, loc.Speed, loc.BatteryLevel, loc.Accuracy, loc.RecordedAt) + } + + query := "INSERT INTO locations (user_id, point, speed, battery_level, accuracy, recorded_at) VALUES " + + strings.Join(valueStrings, ", ") + + _, err := s.pool.Exec(ctx, query, args...) + if err != nil { + return fmt.Errorf("insert locations: %w", err) + } + return nil +} + +// GetLatestLocations returns the most recent location for each circle member. +func (s *Store) GetLatestLocations(ctx context.Context, circleID uuid.UUID) ([]model.Location, error) { + rows, err := s.pool.Query(ctx, ` + SELECT DISTINCT ON (l.user_id) + l.id, + l.user_id, + ST_Y(l.point::geometry) AS lat, + ST_X(l.point::geometry) AS lng, + l.speed, + l.battery_level, + l.accuracy, + l.recorded_at + FROM locations l + JOIN circle_members cm ON cm.user_id = l.user_id + WHERE cm.circle_id = $1 + ORDER BY l.user_id, l.recorded_at DESC + `, circleID) + if err != nil { + return nil, fmt.Errorf("get latest locations: %w", err) + } + defer rows.Close() + + var locs []model.Location + for rows.Next() { + var l model.Location + if err := rows.Scan(&l.ID, &l.UserID, &l.Lat, &l.Lng, &l.Speed, &l.BatteryLevel, &l.Accuracy, &l.RecordedAt); err != nil { + return nil, fmt.Errorf("scan location: %w", err) + } + locs = append(locs, l) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate locations: %w", err) + } + return locs, nil +} + +// GetHistory returns locations for a single user within a time range, ordered by recorded_at ASC. +func (s *Store) GetHistory(ctx context.Context, userID uuid.UUID, from, to time.Time) ([]model.Location, error) { + rows, err := s.pool.Query(ctx, ` + SELECT + id, + user_id, + ST_Y(point::geometry) AS lat, + ST_X(point::geometry) AS lng, + speed, + battery_level, + accuracy, + recorded_at + FROM locations + WHERE user_id = $1 + AND recorded_at >= $2 + AND recorded_at <= $3 + ORDER BY recorded_at ASC + `, userID, from, to) + if err != nil { + return nil, fmt.Errorf("get history: %w", err) + } + defer rows.Close() + + var locs []model.Location + for rows.Next() { + var l model.Location + if err := rows.Scan(&l.ID, &l.UserID, &l.Lat, &l.Lng, &l.Speed, &l.BatteryLevel, &l.Accuracy, &l.RecordedAt); err != nil { + return nil, fmt.Errorf("scan history location: %w", err) + } + locs = append(locs, l) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate history: %w", err) + } + return locs, nil +} + +// DeleteLocationsOlderThan deletes rows older than N days. Returns count of deleted rows. +func (s *Store) DeleteLocationsOlderThan(ctx context.Context, days int) (int64, error) { + tag, err := s.pool.Exec(ctx, + `DELETE FROM locations WHERE recorded_at < now() - ($1 || ' days')::interval`, + days, + ) + if err != nil { + return 0, fmt.Errorf("delete old locations: %w", err) + } + return tag.RowsAffected(), nil +} diff --git a/server/internal/store/migrations.go b/server/internal/store/migrations.go new file mode 100644 index 0000000..5b6f569 --- /dev/null +++ b/server/internal/store/migrations.go @@ -0,0 +1,82 @@ +package store + +import ( + "context" + "embed" + "fmt" + "log" + "sort" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" +) + +//go:embed migrations/*.sql +var migrationFS embed.FS + +func RunMigrations(ctx context.Context, pool *pgxpool.Pool) error { + _, err := pool.Exec(ctx, ` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL DEFAULT now() + ) + `) + if err != nil { + return fmt.Errorf("create schema_migrations: %w", err) + } + + entries, err := migrationFS.ReadDir("migrations") + if err != nil { + return fmt.Errorf("read migrations dir: %w", err) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() < entries[j].Name() + }) + + for _, entry := range entries { + name := entry.Name() + version, err := strconv.Atoi(strings.Split(name, "_")[0]) + if err != nil { + return fmt.Errorf("parse version from %s: %w", name, err) + } + + var exists bool + err = pool.QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version=$1)", version).Scan(&exists) + if err != nil { + return fmt.Errorf("check migration %d: %w", version, err) + } + if exists { + continue + } + + sql, err := migrationFS.ReadFile("migrations/" + name) + if err != nil { + return fmt.Errorf("read migration %s: %w", name, err) + } + + tx, err := pool.Begin(ctx) + if err != nil { + return fmt.Errorf("begin tx for migration %d: %w", version, err) + } + + if _, err := tx.Exec(ctx, string(sql)); err != nil { + tx.Rollback(ctx) + return fmt.Errorf("run migration %d: %w", version, err) + } + + if _, err := tx.Exec(ctx, "INSERT INTO schema_migrations (version) VALUES ($1)", version); err != nil { + tx.Rollback(ctx) + return fmt.Errorf("record migration %d: %w", version, err) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("commit migration %d: %w", version, err) + } + + log.Printf("applied migration %03d: %s", version, name) + } + + return nil +} diff --git a/server/internal/store/migrations/001_initial.sql b/server/internal/store/migrations/001_initial.sql new file mode 100644 index 0000000..97cf055 --- /dev/null +++ b/server/internal/store/migrations/001_initial.sql @@ -0,0 +1,50 @@ +CREATE EXTENSION IF NOT EXISTS postgis; +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + email TEXT UNIQUE NOT NULL, + display_name TEXT NOT NULL, + password_hash TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE circles ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + name TEXT NOT NULL, + invite_code TEXT UNIQUE NOT NULL, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE TABLE circle_members ( + circle_id UUID NOT NULL REFERENCES circles(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + role TEXT NOT NULL DEFAULT 'member', + joined_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (circle_id, user_id) +); + +CREATE TABLE locations ( + id BIGSERIAL PRIMARY KEY, + user_id UUID NOT NULL REFERENCES users(id), + point GEOGRAPHY(Point, 4326) NOT NULL, + speed REAL, + battery_level SMALLINT, + accuracy REAL, + recorded_at TIMESTAMPTZ NOT NULL +); + +CREATE INDEX idx_locations_user_time ON locations (user_id, recorded_at DESC); +CREATE INDEX idx_locations_point ON locations USING GIST (point); +CREATE INDEX idx_locations_recorded_at ON locations USING BRIN (recorded_at); + +CREATE TABLE geofences ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + circle_id UUID NOT NULL REFERENCES circles(id) ON DELETE CASCADE, + name TEXT NOT NULL, + center GEOGRAPHY(Point, 4326) NOT NULL, + radius_meters REAL NOT NULL, + created_by UUID NOT NULL REFERENCES users(id), + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); diff --git a/server/internal/store/migrations/002_fcm_tokens.sql b/server/internal/store/migrations/002_fcm_tokens.sql new file mode 100644 index 0000000..2a44fa4 --- /dev/null +++ b/server/internal/store/migrations/002_fcm_tokens.sql @@ -0,0 +1,6 @@ +CREATE TABLE fcm_tokens ( + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token TEXT NOT NULL, + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (user_id) +); diff --git a/server/internal/store/store.go b/server/internal/store/store.go new file mode 100644 index 0000000..9fa92a4 --- /dev/null +++ b/server/internal/store/store.go @@ -0,0 +1,33 @@ +package store + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +type Store struct { + pool *pgxpool.Pool +} + +func New(ctx context.Context, databaseURL string) (*Store, error) { + pool, err := pgxpool.New(ctx, databaseURL) + if err != nil { + return nil, fmt.Errorf("connect to database: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + return nil, fmt.Errorf("ping database: %w", err) + } + + if err := RunMigrations(ctx, pool); err != nil { + return nil, fmt.Errorf("run migrations: %w", err) + } + + return &Store{pool: pool}, nil +} + +func (s *Store) Close() { + s.pool.Close() +} diff --git a/server/internal/store/store_test.go b/server/internal/store/store_test.go new file mode 100644 index 0000000..50ee6e0 --- /dev/null +++ b/server/internal/store/store_test.go @@ -0,0 +1,353 @@ +package store_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" + "github.com/nschatz/tracker/server/internal/store" +) + +func testStore(t *testing.T) *store.Store { + t.Helper() + url := os.Getenv("TEST_DATABASE_URL") + if url == "" { + url = "postgres://tracker:tracker@localhost:5432/tracker?sslmode=disable" + } + ctx := context.Background() + s, err := store.New(ctx, url) + if err != nil { + t.Skipf("skipping integration test: cannot connect to database: %v", err) + } + t.Cleanup(func() { s.Close() }) + return s +} + +func uniqueEmail(prefix string) string { + return fmt.Sprintf("%s+%s@example.com", prefix, uuid.New().String()) +} + +func TestCreateAndGetUser(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + email := uniqueEmail("testuser") + displayName := "Test User" + passwordHash := "$2a$10$fakehashfortest" + + created, err := s.CreateUser(ctx, email, displayName, passwordHash) + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + if created.Email != email { + t.Errorf("expected email %q, got %q", email, created.Email) + } + if created.DisplayName != displayName { + t.Errorf("expected display_name %q, got %q", displayName, created.DisplayName) + } + if created.PasswordHash != passwordHash { + t.Errorf("expected password_hash to match") + } + if created.ID.String() == "" { + t.Errorf("expected non-empty ID") + } + + byEmail, err := s.GetUserByEmail(ctx, email) + if err != nil { + t.Fatalf("GetUserByEmail: %v", err) + } + if byEmail.ID != created.ID { + t.Errorf("GetUserByEmail: ID mismatch: got %v, want %v", byEmail.ID, created.ID) + } + + byID, err := s.GetUserByID(ctx, created.ID) + if err != nil { + t.Fatalf("GetUserByID: %v", err) + } + if byID.Email != email { + t.Errorf("GetUserByID: email mismatch: got %q, want %q", byID.Email, email) + } +} + +func TestCreateCircleAndJoin(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + // Create circle owner + owner, err := s.CreateUser(ctx, uniqueEmail("owner"), "Owner", "hash") + if err != nil { + t.Fatalf("CreateUser (owner): %v", err) + } + + // Create circle + circle, err := s.CreateCircle(ctx, "Test Circle", owner.ID) + if err != nil { + t.Fatalf("CreateCircle: %v", err) + } + if circle.Name != "Test Circle" { + t.Errorf("expected circle name %q, got %q", "Test Circle", circle.Name) + } + if circle.InviteCode == "" { + t.Error("expected non-empty invite code") + } + if len(circle.InviteCode) != 12 { + t.Errorf("expected invite code length 12 (6 bytes hex), got %d", len(circle.InviteCode)) + } + if circle.CreatedBy != owner.ID { + t.Errorf("expected created_by %v, got %v", owner.ID, circle.CreatedBy) + } + + // Owner should be an admin member already + members, err := s.GetMembers(ctx, circle.ID) + if err != nil { + t.Fatalf("GetMembers: %v", err) + } + if len(members) != 1 { + t.Fatalf("expected 1 member after create, got %d", len(members)) + } + if members[0].UserID != owner.ID { + t.Errorf("expected owner as first member") + } + if members[0].Role != "admin" { + t.Errorf("expected role 'admin', got %q", members[0].Role) + } + + // Create a second user and add them + joiner, err := s.CreateUser(ctx, uniqueEmail("joiner"), "Joiner", "hash") + if err != nil { + t.Fatalf("CreateUser (joiner): %v", err) + } + + if err := s.AddMember(ctx, circle.ID, joiner.ID, "member"); err != nil { + t.Fatalf("AddMember: %v", err) + } + + // AddMember with ON CONFLICT DO NOTHING should be idempotent + if err := s.AddMember(ctx, circle.ID, joiner.ID, "member"); err != nil { + t.Fatalf("AddMember (duplicate) should not fail: %v", err) + } + + members, err = s.GetMembers(ctx, circle.ID) + if err != nil { + t.Fatalf("GetMembers after join: %v", err) + } + if len(members) != 2 { + t.Errorf("expected 2 members, got %d", len(members)) + } + + // Look up circle by invite code + found, err := s.GetCircleByInviteCode(ctx, circle.InviteCode) + if err != nil { + t.Fatalf("GetCircleByInviteCode: %v", err) + } + if found.ID != circle.ID { + t.Errorf("GetCircleByInviteCode: ID mismatch") + } + + // GetUserCircles for joiner should include the circle + circles, err := s.GetUserCircles(ctx, joiner.ID) + if err != nil { + t.Fatalf("GetUserCircles: %v", err) + } + if len(circles) != 1 { + t.Errorf("expected 1 circle for joiner, got %d", len(circles)) + } + if circles[0].ID != circle.ID { + t.Errorf("GetUserCircles: wrong circle returned") + } +} + +func TestGeofenceCRUD(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + owner, err := s.CreateUser(ctx, uniqueEmail("geo-owner"), "Geo Owner", "hash") + if err != nil { + t.Fatalf("CreateUser: %v", err) + } + circle, err := s.CreateCircle(ctx, "Geo Circle", owner.ID) + if err != nil { + t.Fatalf("CreateCircle: %v", err) + } + + // Create a geofence centered on Times Square, NYC + lat, lng := 40.7580, -73.9855 + radius := float32(100.0) // 100 meters + + gf, err := s.CreateGeofence(ctx, circle.ID, "Times Square", lat, lng, radius, owner.ID) + if err != nil { + t.Fatalf("CreateGeofence: %v", err) + } + if gf.ID == uuid.Nil { + t.Error("expected non-nil geofence ID") + } + if gf.Name != "Times Square" { + t.Errorf("expected name 'Times Square', got %q", gf.Name) + } + if gf.CircleID != circle.ID { + t.Errorf("expected circle_id %v, got %v", circle.ID, gf.CircleID) + } + // Allow small floating point tolerance + if diff := gf.Lat - lat; diff > 0.0001 || diff < -0.0001 { + t.Errorf("expected lat ~%v, got %v", lat, gf.Lat) + } + if diff := gf.Lng - lng; diff > 0.0001 || diff < -0.0001 { + t.Errorf("expected lng ~%v, got %v", lng, gf.Lng) + } + if gf.RadiusMeters != radius { + t.Errorf("expected radius %v, got %v", radius, gf.RadiusMeters) + } + + // List geofences + list, err := s.GetGeofences(ctx, circle.ID) + if err != nil { + t.Fatalf("GetGeofences: %v", err) + } + if len(list) != 1 { + t.Fatalf("expected 1 geofence, got %d", len(list)) + } + if list[0].ID != gf.ID { + t.Errorf("GetGeofences: wrong ID returned") + } + + // Update geofence + newLat, newLng := 40.7589, -73.9851 + newRadius := float32(200.0) + updated, err := s.UpdateGeofence(ctx, gf.ID, "Times Square Updated", newLat, newLng, newRadius) + if err != nil { + t.Fatalf("UpdateGeofence: %v", err) + } + if updated.Name != "Times Square Updated" { + t.Errorf("expected updated name, got %q", updated.Name) + } + if updated.RadiusMeters != newRadius { + t.Errorf("expected updated radius %v, got %v", newRadius, updated.RadiusMeters) + } + + // FindContainingGeofences — point inside (same center, radius is 200m) + inside, err := s.FindContainingGeofences(ctx, circle.ID, newLat, newLng) + if err != nil { + t.Fatalf("FindContainingGeofences (inside): %v", err) + } + if len(inside) != 1 { + t.Fatalf("expected 1 containing geofence for inside point, got %d", len(inside)) + } + if inside[0] != gf.ID { + t.Errorf("FindContainingGeofences: wrong geofence returned") + } + + // FindContainingGeofences — point far outside (Los Angeles) + outside, err := s.FindContainingGeofences(ctx, circle.ID, 34.0522, -118.2437) + if err != nil { + t.Fatalf("FindContainingGeofences (outside): %v", err) + } + if len(outside) != 0 { + t.Errorf("expected 0 containing geofences for outside point, got %d", len(outside)) + } + + // Delete geofence + if err := s.DeleteGeofence(ctx, gf.ID); err != nil { + t.Fatalf("DeleteGeofence: %v", err) + } + + // Verify deleted + list, err = s.GetGeofences(ctx, circle.ID) + if err != nil { + t.Fatalf("GetGeofences after delete: %v", err) + } + if len(list) != 0 { + t.Errorf("expected 0 geofences after delete, got %d", len(list)) + } +} + +func TestInsertAndQueryLocations(t *testing.T) { + s := testStore(t) + ctx := context.Background() + + // Create owner and circle + owner, err := s.CreateUser(ctx, uniqueEmail("loc-owner"), "Loc Owner", "hash") + if err != nil { + t.Fatalf("CreateUser (owner): %v", err) + } + circle, err := s.CreateCircle(ctx, "Loc Circle", owner.ID) + if err != nil { + t.Fatalf("CreateCircle: %v", err) + } + + // Create a member and add them to the circle + member, err := s.CreateUser(ctx, uniqueEmail("loc-member"), "Loc Member", "hash") + if err != nil { + t.Fatalf("CreateUser (member): %v", err) + } + if err := s.AddMember(ctx, circle.ID, member.ID, "member"); err != nil { + t.Fatalf("AddMember: %v", err) + } + + // Insert 3 location points at different times + now := time.Now().UTC().Truncate(time.Second) + speed := float32(5.0) + battery := int16(80) + accuracy := float32(10.0) + + locs := []model.LocationInput{ + {Lat: 40.7128, Lng: -74.0060, Speed: &speed, BatteryLevel: &battery, Accuracy: &accuracy, RecordedAt: now.Add(-2 * time.Minute)}, + {Lat: 40.7130, Lng: -74.0058, RecordedAt: now.Add(-1 * time.Minute)}, + {Lat: 40.7135, Lng: -74.0055, RecordedAt: now}, + } + + if err := s.InsertLocations(ctx, member.ID, locs); err != nil { + t.Fatalf("InsertLocations: %v", err) + } + + // GetLatestLocations should return newest point for the member + latest, err := s.GetLatestLocations(ctx, circle.ID) + if err != nil { + t.Fatalf("GetLatestLocations: %v", err) + } + + // Find the member's entry in results + var memberLoc *model.Location + for i := range latest { + if latest[i].UserID == member.ID { + memberLoc = &latest[i] + break + } + } + if memberLoc == nil { + t.Fatal("GetLatestLocations: member not found in results") + } + + // Should be the newest point (index 2) + if !memberLoc.RecordedAt.Equal(now) { + t.Errorf("GetLatestLocations: expected newest point at %v, got %v", now, memberLoc.RecordedAt) + } + if memberLoc.UserID != member.ID { + t.Errorf("GetLatestLocations: wrong user_id: got %v, want %v", memberLoc.UserID, member.ID) + } + + // GetHistory should return all 3 points + history, err := s.GetHistory(ctx, member.ID, now.Add(-10*time.Minute), now.Add(time.Minute)) + if err != nil { + t.Fatalf("GetHistory: %v", err) + } + if len(history) != 3 { + t.Fatalf("GetHistory: expected 3 points, got %d", len(history)) + } + // Verify ordered ASC + if !history[0].RecordedAt.Before(history[1].RecordedAt) { + t.Errorf("GetHistory: expected ascending order, got %v then %v", history[0].RecordedAt, history[1].RecordedAt) + } + if !history[1].RecordedAt.Before(history[2].RecordedAt) { + t.Errorf("GetHistory: expected ascending order, got %v then %v", history[1].RecordedAt, history[2].RecordedAt) + } + // Verify all belong to the member + for _, h := range history { + if h.UserID != member.ID { + t.Errorf("GetHistory: unexpected user_id %v", h.UserID) + } + } +} diff --git a/server/internal/store/users.go b/server/internal/store/users.go new file mode 100644 index 0000000..ac2035d --- /dev/null +++ b/server/internal/store/users.go @@ -0,0 +1,49 @@ +package store + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" +) + +func (s *Store) CreateUser(ctx context.Context, email, displayName, passwordHash string) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `INSERT INTO users (email, display_name, password_hash) + VALUES ($1, $2, $3) + RETURNING id, email, display_name, password_hash, created_at`, + email, displayName, passwordHash, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt) + if err != nil { + return nil, fmt.Errorf("create user: %w", err) + } + return &u, nil +} + +func (s *Store) GetUserByEmail(ctx context.Context, email string) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `SELECT id, email, display_name, password_hash, created_at + FROM users WHERE email = $1`, + email, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get user by email: %w", err) + } + return &u, nil +} + +func (s *Store) GetUserByID(ctx context.Context, id uuid.UUID) (*model.User, error) { + var u model.User + err := s.pool.QueryRow(ctx, + `SELECT id, email, display_name, password_hash, created_at + FROM users WHERE id = $1`, + id, + ).Scan(&u.ID, &u.Email, &u.DisplayName, &u.PasswordHash, &u.CreatedAt) + if err != nil { + return nil, fmt.Errorf("get user by id: %w", err) + } + return &u, nil +} diff --git a/server/internal/ws/hub.go b/server/internal/ws/hub.go new file mode 100644 index 0000000..23e8925 --- /dev/null +++ b/server/internal/ws/hub.go @@ -0,0 +1,122 @@ +package ws + +import ( + "context" + "encoding/json" + "log" + "net/http" + "sync" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" + "nhooyr.io/websocket" +) + +type client struct { + conn *websocket.Conn + userID uuid.UUID + circleID uuid.UUID + cancel context.CancelFunc +} + +type Hub struct { + mu sync.RWMutex + clients map[*client]struct{} + register chan *client + unregister chan *client +} + +func NewHub() *Hub { + return &Hub{ + clients: make(map[*client]struct{}), + register: make(chan *client, 16), + unregister: make(chan *client, 16), + } +} + +func (h *Hub) Run() { + for { + select { + case c := <-h.register: + h.mu.Lock() + h.clients[c] = struct{}{} + h.mu.Unlock() + case c := <-h.unregister: + h.mu.Lock() + delete(h.clients, c) + h.mu.Unlock() + } + } +} + +func (h *Hub) HandleConnect(w http.ResponseWriter, r *http.Request, userID, circleID uuid.UUID) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + log.Printf("ws: accept error: %v", err) + return + } + + ctx, cancel := context.WithCancel(r.Context()) + c := &client{ + conn: conn, + userID: userID, + circleID: circleID, + cancel: cancel, + } + + h.register <- c + + // Read loop: keeps connection alive and detects close + defer func() { + h.unregister <- c + cancel() + conn.Close(websocket.StatusNormalClosure, "") + }() + + for { + _, _, err := conn.Read(ctx) + if err != nil { + return + } + } +} + +func (h *Hub) BroadcastLocation(circleID uuid.UUID, loc model.Location) { + data, err := json.Marshal(loc) + if err != nil { + log.Printf("ws: marshal error: %v", err) + return + } + + h.mu.RLock() + targets := make([]*client, 0) + for c := range h.clients { + if c.circleID == circleID { + targets = append(targets, c) + } + } + h.mu.RUnlock() + + for _, c := range targets { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := c.conn.Write(ctx, websocket.MessageText, data) + cancel() + if err != nil { + log.Printf("ws: write error to client %s: %v", c.userID, err) + } + } +} + +func (h *Hub) IsConnected(userID uuid.UUID) bool { + h.mu.RLock() + defer h.mu.RUnlock() + for c := range h.clients { + if c.userID == userID { + return true + } + } + return false +} diff --git a/server/internal/ws/hub_test.go b/server/internal/ws/hub_test.go new file mode 100644 index 0000000..571c93d --- /dev/null +++ b/server/internal/ws/hub_test.go @@ -0,0 +1,80 @@ +package ws_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/nschatz/tracker/server/internal/model" + "github.com/nschatz/tracker/server/internal/ws" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wsjson" +) + +func TestHubBroadcast(t *testing.T) { + hub := ws.NewHub() + go hub.Run() + + userID := uuid.New() + circleID := uuid.New() + + // Create httptest.Server that calls hub.HandleConnect for a specific userID/circleID + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hub.HandleConnect(w, r, userID, circleID) + })) + defer srv.Close() + + // Connect via websocket (nhooyr.io/websocket client) + wsURL := "ws" + srv.URL[4:] // replace "http" with "ws" + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, _, err := websocket.Dial(ctx, wsURL, nil) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer conn.Close(websocket.StatusNormalClosure, "") + + // Wait briefly for registration + time.Sleep(50 * time.Millisecond) + + // BroadcastLocation for the same circleID + loc := model.Location{ + ID: 1, + UserID: userID, + Lat: 37.7749, + Lng: -122.4194, + RecordedAt: time.Now().UTC().Truncate(time.Second), + } + hub.BroadcastLocation(circleID, loc) + + // Read the message from the websocket connection + var received model.Location + if err := wsjson.Read(ctx, conn, &received); err != nil { + t.Fatalf("failed to read message: %v", err) + } + + // Verify it's the correct location JSON + if received.UserID != loc.UserID { + t.Errorf("expected userID %s, got %s", loc.UserID, received.UserID) + } + if received.Lat != loc.Lat { + t.Errorf("expected lat %f, got %f", loc.Lat, received.Lat) + } + if received.Lng != loc.Lng { + t.Errorf("expected lng %f, got %f", loc.Lng, received.Lng) + } + + // Verify IsConnected returns true for the connected user + if !hub.IsConnected(userID) { + t.Error("expected user to be connected") + } + + // Verify IsConnected returns false for an unknown user + if hub.IsConnected(uuid.New()) { + t.Error("expected unknown user to not be connected") + } +}