First release of open core
This commit is contained in:
41
pkg/auth/auth.go
Normal file
41
pkg/auth/auth.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math/rand"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Handler encapsulates all Identity and Access HTTP logic
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
// NewHandler creates a new Auth Handler
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
|
||||
// HashPassword takes a plaintext password, automatically generates a secure salt
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CheckPasswordHash securely compares a plaintext password with a stored bcrypt hash.
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a cryptographically secure random string
|
||||
func GenerateSessionToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
140
pkg/auth/auth_handlers.go
Normal file
140
pkg/auth/auth_handlers.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const SessionCookieName = "session_token"
|
||||
|
||||
// RegisterRequest represents the JSON payload expected for user registration.
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email"`
|
||||
FullName string `json:"full_name"`
|
||||
Password string `json:"password"`
|
||||
GlobalRole string `json:"global_role"`
|
||||
}
|
||||
|
||||
// LoginRequest represents the JSON payload expected for user login.
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// HandleRegister processes new user signups.
|
||||
func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var req RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.Store.GetUserCount(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
http.Error(w, "Forbidden: System already initialized. Contact your Sheriff for an account.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
req.GlobalRole = "Sheriff"
|
||||
|
||||
if req.Email == "" || req.Password == "" || req.FullName == "" {
|
||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := HashPassword(req.Password)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to hash password", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.CreateUser(r.Context(), req.Email, req.FullName, hashedPassword, req.GlobalRole)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
http.Error(w, "Email already exists", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, "Failed to create user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// HandleLogin authenticates a user and issues a session cookie.
|
||||
func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByEmail(r.Context(), req.Email)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if !CheckPasswordHash(req.Password, user.PasswordHash) {
|
||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := GenerateSessionToken()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
if err := h.Store.CreateSession(r.Context(), token, user.ID, expiresAt); err != nil {
|
||||
http.Error(w, "Failed to persist session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: token,
|
||||
Expires: expiresAt,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to TRUE in production for HTTPS!
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// HandleLogout destroys the user's session in the database and clears their cookie.
|
||||
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(SessionCookieName)
|
||||
|
||||
if err == nil && cookie.Value != "" {
|
||||
_ = h.Store.DeleteSession(r.Context(), cookie.Value)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true, // Ensures it's only sent over HTTPS
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Successfully logged out",
|
||||
})
|
||||
}
|
||||
111
pkg/auth/auth_handlers_test.go
Normal file
111
pkg/auth/auth_handlers_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
)
|
||||
|
||||
func setupTestAuth(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
|
||||
h := NewHandler(store)
|
||||
|
||||
return h, db
|
||||
}
|
||||
|
||||
func TestAuthHandlers(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
t.Run("Successful Registration", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"full_name": "Doc Holliday",
|
||||
"password": "SuperSecretPassword123!",
|
||||
"global_role": "Sheriff", // Use a valid role!
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleRegister(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created for registration, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Login Issues Cookie", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"password": "SuperSecretPassword123!",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK for successful login, got %d", rr.Code)
|
||||
}
|
||||
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatalf("Expected a session cookie to be set, but none was found")
|
||||
}
|
||||
if cookies[0].Name != "session_token" {
|
||||
t.Errorf("Expected cookie named 'session_token', got '%s'", cookies[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Failed Login Rejects Access", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"password": "WrongPassword!",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("Expected 401 Unauthorized for wrong password, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleLogout(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/logout", nil)
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: "fake-session-token-123",
|
||||
}
|
||||
req.AddCookie(cookie)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleLogout(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
49
pkg/auth/auth_test.go
Normal file
49
pkg/auth/auth_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
password := "SuperSecretSOCPassword123!"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
if hash == password {
|
||||
t.Fatalf("Security failure: Hash matches plain text!")
|
||||
}
|
||||
if len(hash) == 0 {
|
||||
t.Fatalf("Hash is empty")
|
||||
}
|
||||
|
||||
isValid := CheckPasswordHash(password, hash)
|
||||
if !isValid {
|
||||
t.Errorf("Expected valid password to match hash, but it failed")
|
||||
}
|
||||
|
||||
isInvalid := CheckPasswordHash("WrongPassword!", hash)
|
||||
if isInvalid {
|
||||
t.Errorf("Security failure: Incorrect password returned true!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionToken(t *testing.T) {
|
||||
|
||||
token1, err1 := GenerateSessionToken()
|
||||
token2, err2 := GenerateSessionToken()
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("Failed to generate session tokens")
|
||||
}
|
||||
|
||||
if len(token1) < 32 {
|
||||
t.Errorf("Token is too short for security standards: %d chars", len(token1))
|
||||
}
|
||||
|
||||
if token1 == token2 {
|
||||
t.Errorf("CRITICAL: RNG generated the exact same token twice: %s", token1)
|
||||
}
|
||||
}
|
||||
56
pkg/auth/middleware.go
Normal file
56
pkg/auth/middleware.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const UserIDKey contextKey = "user_id"
|
||||
|
||||
func (h *Handler) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized: Missing session cookie", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.Store.GetSession(r.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized: Invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
http.Error(w, "Unauthorized: Session expired", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, session.UserID)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireUIAuth checks for a valid session and redirects to /login if it fails,
|
||||
func (h *Handler) RequireUIAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.Store.GetSession(r.Context(), cookie.Value)
|
||||
if err != nil || session.ExpiresAt.Before(time.Now()) {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, session.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
61
pkg/auth/middleware_test.go
Normal file
61
pkg/auth/middleware_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRequireAuthMiddleware(t *testing.T) {
|
||||
h, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
user, err := h.Store.CreateUser(context.Background(), "vip@RiskRancher.com", "Wyatt Earp", "fake_hash", "Sheriff")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seed test user: %v", err)
|
||||
}
|
||||
|
||||
validToken := "valid_test_token_123"
|
||||
expiresAt := time.Now().Add(1 * time.Hour)
|
||||
err = h.Store.CreateSession(context.Background(), validToken, user.ID, expiresAt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seed test session: %v", err)
|
||||
}
|
||||
|
||||
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Welcome to the VIP room"))
|
||||
})
|
||||
protectedHandler := h.RequireAuth(dummyHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cookieName string
|
||||
cookieValue string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"Missing Cookie", "", "", http.StatusUnauthorized},
|
||||
{"Wrong Cookie Name", "wrong_name", validToken, http.StatusUnauthorized},
|
||||
{"Invalid Token", "session_token", "fake_invalid_token", http.StatusUnauthorized},
|
||||
{"Valid Token", "session_token", validToken, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
if tt.cookieName != "" {
|
||||
req.AddCookie(&http.Cookie{Name: tt.cookieName, Value: tt.cookieValue})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
protectedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
74
pkg/auth/rbac_middleware.go
Normal file
74
pkg/auth/rbac_middleware.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RequireRole acts as the checker
|
||||
func (h *Handler) RequireRole(requiredRole string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userIDVal := r.Context().Value(UserIDKey)
|
||||
if userIDVal == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := userIDVal.(int)
|
||||
if !ok {
|
||||
http.Error(w, "Internal Server Error: Invalid user context", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden: User not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if user.GlobalRole != requiredRole {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyRole allows access if the user has ANY of the provided roles.
|
||||
func (h *Handler) RequireAnyRole(allowedRoles ...string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userIDVal := r.Context().Value(UserIDKey)
|
||||
if userIDVal == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := userIDVal.(int)
|
||||
if !ok {
|
||||
http.Error(w, "Internal Server Error: Invalid user context", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden: User not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
for _, role := range allowedRoles {
|
||||
if user.GlobalRole == role {
|
||||
// Match found! Open the door.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
}
|
||||
49
pkg/auth/rbac_middleware_test.go
Normal file
49
pkg/auth/rbac_middleware_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequireRoleMiddleware(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
sheriff, _ := a.Store.CreateUser(context.Background(), "sheriff@ranch.com", "Wyatt Earp", "hash", "Sheriff")
|
||||
rangeHand, _ := a.Store.CreateUser(context.Background(), "hand@ranch.com", "Jesse James", "hash", "RangeHand")
|
||||
|
||||
vipHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Welcome to the Manager's Office"))
|
||||
})
|
||||
|
||||
protectedHandler := a.RequireRole("Sheriff")(vipHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int
|
||||
expectedStatus int
|
||||
}{
|
||||
{"Valid Sheriff Access", sheriff.ID, http.StatusOK},
|
||||
{"Denied RangeHand Access", rangeHand.ID, http.StatusForbidden},
|
||||
{"Unknown User", 9999, http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/passwords", nil)
|
||||
|
||||
ctx := context.WithValue(req.Context(), UserIDKey, tt.userID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
protectedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user