First release of open core
This commit is contained in:
147
pkg/adapters/adapters.go
Normal file
147
pkg/adapters/adapters.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (h *Handler) HandleGetAdapters(w http.ResponseWriter, r *http.Request) {
|
||||
adapters, err := h.Store.GetAdapters(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(adapters)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleCreateAdapter(w http.ResponseWriter, r *http.Request) {
|
||||
var adapter domain2.Adapter
|
||||
if err := json.NewDecoder(r.Body).Decode(&adapter); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := h.Store.SaveAdapter(r.Context(), adapter); err != nil {
|
||||
http.Error(w, "Failed to save adapter", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleDeleteAdapter(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := r.PathValue("id")
|
||||
id, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid adapter ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Store.DeleteAdapter(r.Context(), id); err != nil {
|
||||
http.Error(w, "Failed to delete adapter", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func getJSONValue(data interface{}, path string) interface{} {
|
||||
if path == "" || path == "." {
|
||||
return data // The root IS the array
|
||||
}
|
||||
keys := strings.Split(path, ".")
|
||||
current := data
|
||||
for _, key := range keys {
|
||||
if m, ok := current.(map[string]interface{}); ok {
|
||||
current = m[key]
|
||||
} else {
|
||||
return nil // Path broke
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
func interfaceToString(val interface{}) string {
|
||||
if val == nil {
|
||||
return ""
|
||||
}
|
||||
if str, ok := val.(string); ok {
|
||||
return str
|
||||
}
|
||||
return "" // Could expand this to handle ints/floats if needed
|
||||
}
|
||||
|
||||
// HandleAdapterIngest dynamically maps deeply nested JSON arrays into Tickets
|
||||
func (h *Handler) HandleAdapterIngest(w http.ResponseWriter, r *http.Request) {
|
||||
adapterName := r.PathValue("name")
|
||||
adapter, err := h.Store.GetAdapterByName(r.Context(), adapterName)
|
||||
if err != nil {
|
||||
http.Error(w, "Adapter not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var rawData interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&rawData); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
findingsNode := getJSONValue(rawData, adapter.FindingsPath)
|
||||
findingsArray, ok := findingsNode.([]interface{})
|
||||
if !ok {
|
||||
http.Error(w, "Findings path did not resolve to a JSON array", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
type groupKey struct {
|
||||
Source string
|
||||
Asset string
|
||||
}
|
||||
groupedTickets := make(map[groupKey][]domain2.Ticket)
|
||||
|
||||
for _, item := range findingsArray {
|
||||
finding, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
ticket := domain2.Ticket{
|
||||
Source: adapter.SourceName,
|
||||
Status: "Waiting to be Triaged", // Explicitly set status
|
||||
Title: interfaceToString(finding[adapter.MappingTitle]),
|
||||
AssetIdentifier: interfaceToString(finding[adapter.MappingAsset]),
|
||||
Severity: interfaceToString(finding[adapter.MappingSeverity]),
|
||||
Description: interfaceToString(finding[adapter.MappingDescription]),
|
||||
RecommendedRemediation: interfaceToString(finding[adapter.MappingRemediation]),
|
||||
}
|
||||
|
||||
if ticket.Title != "" && ticket.AssetIdentifier != "" {
|
||||
hashInput := ticket.Source + "|" + ticket.AssetIdentifier + "|" + ticket.Title
|
||||
hash := sha256.Sum256([]byte(hashInput))
|
||||
ticket.DedupeHash = hex.EncodeToString(hash[:])
|
||||
key := groupKey{Source: ticket.Source, Asset: ticket.AssetIdentifier}
|
||||
groupedTickets[key] = append(groupedTickets[key], ticket)
|
||||
}
|
||||
}
|
||||
|
||||
for key, batch := range groupedTickets {
|
||||
err := h.Store.ProcessIngestionBatch(r.Context(), key.Source, key.Asset, batch)
|
||||
if err != nil {
|
||||
log.Printf("🔥 JSON Ingestion Error for Asset %s: %v", key.Asset, err)
|
||||
// 🚀 LOG THE BATCH FAILURE
|
||||
h.Store.LogSync(r.Context(), key.Source, "Failed", len(batch), err.Error())
|
||||
http.Error(w, "Database error processing JSON batch", http.StatusInternalServerError)
|
||||
return
|
||||
} else {
|
||||
// 🚀 LOG THE SUCCESS
|
||||
h.Store.LogSync(r.Context(), key.Source, "Success", len(batch), "")
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
142
pkg/adapters/adapters_test.go
Normal file
142
pkg/adapters/adapters_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func setupTestAdapters(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
user, err := store.GetUserByEmail(context.Background(), "vip@RiskRancher.com")
|
||||
if err != nil {
|
||||
user, _ = store.CreateUser(context.Background(), "vip@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
}
|
||||
|
||||
store.CreateSession(context.Background(), "vip_token_999", user.ID, time.Now().Add(1*time.Hour))
|
||||
return &http.Cookie{Name: "session_token", Value: "vip_token_999"}
|
||||
}
|
||||
|
||||
func TestHandleAdapterIngest(t *testing.T) {
|
||||
h, db := setupTestAdapters(t)
|
||||
defer db.Close()
|
||||
|
||||
adapterPayload := []byte(`{"name": "Trivy Test", "source_name": "TrivyScanner", "findings_path": "Results", "mapping_title": "VulnerabilityID", "mapping_asset": "Target", "mapping_severity": "Severity"}`)
|
||||
reqAdapter := httptest.NewRequest(http.MethodPost, "/api/adapters", bytes.NewBuffer(adapterPayload))
|
||||
reqAdapter.AddCookie(GetVIPCookie(h.Store))
|
||||
reqAdapter.Header.Set("Content-Type", "application/json")
|
||||
rrAdapter := httptest.NewRecorder()
|
||||
|
||||
h.HandleCreateAdapter(rrAdapter, reqAdapter)
|
||||
|
||||
payload := []byte(`{"SchemaVersion": 2, "Results": [{"VulnerabilityID": "CVE-1", "Target": "A", "Severity": "HIGH"}]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest/Trivy%20Test", bytes.NewBuffer(payload))
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
req.SetPathValue("name", "Trivy Test")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleAdapterIngest(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAdapters(t *testing.T) {
|
||||
h, db := setupTestAdapters(t)
|
||||
defer db.Close()
|
||||
|
||||
db.Exec(`INSERT INTO data_adapters (name, source_name, findings_path, mapping_title, mapping_asset, mapping_severity) VALUES ('Trivy Test', 'Trivy', 'Results', 'VulnerabilityID', 'PkgName', 'Severity')`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/adapters", nil)
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleGetAdapters(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAdapter(t *testing.T) {
|
||||
h, db := setupTestAdapters(t)
|
||||
defer db.Close()
|
||||
|
||||
payload := []byte(`{"name": "AcmeSec", "source_name": "Acme", "findings_path": "issues", "mapping_title": "t", "mapping_asset": "a", "mapping_severity": "s"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/adapters", bytes.NewBuffer(payload))
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleCreateAdapter(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONIngestion(t *testing.T) {
|
||||
h, db := setupTestAdapters(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO data_adapters (
|
||||
id, name, source_name, findings_path,
|
||||
mapping_title, mapping_asset, mapping_severity
|
||||
) VALUES (
|
||||
998, 'NestedScanner', 'DeepScan', 'scan_data.results',
|
||||
'vuln_name', 'target_ip', 'risk_level'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup nested adapter: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte(`{
|
||||
"metadata": { "version": "1.0" },
|
||||
"scan_data": {
|
||||
"results": [
|
||||
{
|
||||
"vuln_name": "Log4j RCE",
|
||||
"target_ip": "10.0.0.5",
|
||||
"risk_level": "Critical"
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest/NestedScanner", bytes.NewBuffer(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
|
||||
req.SetPathValue("name", "NestedScanner")
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleAdapterIngest(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var title, severity string
|
||||
err = db.QueryRow("SELECT title, severity FROM tickets WHERE source = 'DeepScan'").Scan(&title, &severity)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query ingested ticket: %v", err)
|
||||
}
|
||||
|
||||
if title != "Log4j RCE" || severity != "Critical" {
|
||||
t.Errorf("JSON Mapping failed! Expected 'Log4j RCE' / 'Critical', got '%s' / '%s'", title, severity)
|
||||
}
|
||||
}
|
||||
13
pkg/adapters/handler.go
Normal file
13
pkg/adapters/handler.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package adapters
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
62
pkg/admin/admin.go
Normal file
62
pkg/admin/admin.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (h *Handler) HandleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
config, err := h.Store.GetAppConfig(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to fetch configuration", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(config)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleExportState(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := h.Store.ExportSystemState(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate system export", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Disposition", "attachment; filename=RiskRancher_export.json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(state); err != nil {
|
||||
// Note: We can't change the HTTP status code here because we've already started streaming,
|
||||
// but we can log the error if the stream breaks.
|
||||
_ = err
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) HandleGetLogs(w http.ResponseWriter, r *http.Request) {
|
||||
filter := r.URL.Query().Get("filter")
|
||||
page, err := strconv.Atoi(r.URL.Query().Get("page"))
|
||||
if err != nil || page < 1 {
|
||||
page = 1
|
||||
}
|
||||
|
||||
limit := 15
|
||||
offset := (page - 1) * limit
|
||||
|
||||
feed, total, err := h.Store.GetPaginatedActivityFeed(r.Context(), filter, limit, offset)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to load logs", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"feed": feed,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"limit": limit,
|
||||
})
|
||||
}
|
||||
192
pkg/admin/admin_handlers.go
Normal file
192
pkg/admin/admin_handlers.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/auth"
|
||||
)
|
||||
|
||||
// PasswordResetRequest is the expected JSON payload
|
||||
type PasswordResetRequest struct {
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
// HandleAdminResetPassword allows a Sheriff to forcefully overwrite a user's password.
|
||||
func (h *Handler) HandleAdminResetPassword(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := r.PathValue("id")
|
||||
userID, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid user ID in URL", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req PasswordResetRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.NewPassword == "" {
|
||||
http.Error(w, "New password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal server error during hashing", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.Store.UpdateUserPassword(r.Context(), userID, hashedPassword)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to update user password", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Password reset successfully",
|
||||
})
|
||||
}
|
||||
|
||||
type RoleUpdateRequest struct {
|
||||
GlobalRole string `json:"global_role"`
|
||||
}
|
||||
|
||||
// HandleUpdateUserRole allows a Sheriff to promote or demote a user.
|
||||
func (h *Handler) HandleUpdateUserRole(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := r.PathValue("id")
|
||||
userID, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid user ID in URL", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var req RoleUpdateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
validRoles := map[string]bool{
|
||||
"Sheriff": true, "Wrangler": true, "RangeHand": true, "CircuitRider": true, "Magistrate": true,
|
||||
}
|
||||
if !validRoles[req.GlobalRole] {
|
||||
http.Error(w, "Invalid role provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.Store.UpdateUserRole(r.Context(), userID, req.GlobalRole)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to update user role", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "User role updated successfully to " + req.GlobalRole,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleDeactivateUser allows a Sheriff to safely offboard a user.
|
||||
func (h *Handler) HandleDeactivateUser(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := r.PathValue("id")
|
||||
userID, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid user ID in URL", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.Store.DeactivateUserAndReassign(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to deactivate user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "User successfully deactivated and tickets reassigned.",
|
||||
})
|
||||
}
|
||||
|
||||
// CreateUserRequest is the payload the Sheriff sends to invite a new user
|
||||
type CreateUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
FullName string `json:"full_name"`
|
||||
Password string `json:"password"`
|
||||
GlobalRole string `json:"global_role"`
|
||||
}
|
||||
|
||||
// HandleCreateUser allows a Sheriff to manually provision a new user account.
|
||||
func (h *Handler) HandleCreateUser(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateUserRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Email == "" || req.FullName == "" || req.Password == "" || req.GlobalRole == "" {
|
||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
validRoles := map[string]bool{
|
||||
"Sheriff": true, "Wrangler": true, "RangeHand": true, "CircuitRider": true, "Magistrate": true,
|
||||
}
|
||||
if !validRoles[req.GlobalRole] {
|
||||
http.Error(w, "Invalid role provided", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal server error during hashing", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.CreateUser(r.Context(), req.Email, req.FullName, hashedPassword, req.GlobalRole)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
http.Error(w, "Email already exists in the system", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, "Failed to provision user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"message": "User provisioned successfully. Share the temporary password securely.",
|
||||
"id": user.ID,
|
||||
"email": user.Email,
|
||||
"full_name": user.FullName,
|
||||
"global_role": user.GlobalRole,
|
||||
})
|
||||
}
|
||||
|
||||
// HandleGetUsers returns a list of all users in the system for the Sheriff to manage.
|
||||
func (h *Handler) HandleGetUsers(w http.ResponseWriter, r *http.Request) {
|
||||
users, err := h.Store.GetAllUsers(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to fetch user roster", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(users)
|
||||
}
|
||||
|
||||
// HandleGetWranglers returns a clean list of IT users for assignment dropdowns
|
||||
func (h *Handler) HandleGetWranglers(w http.ResponseWriter, r *http.Request) {
|
||||
wranglers, err := h.Store.GetWranglers(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to fetch wranglers", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(wranglers)
|
||||
}
|
||||
69
pkg/admin/admin_lifecycle.go
Normal file
69
pkg/admin/admin_lifecycle.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const CurrentAppVersion = "v1.0.0"
|
||||
|
||||
type UpdateCheckResponse struct {
|
||||
Status string `json:"status"`
|
||||
CurrentVersion string `json:"current_version"`
|
||||
LatestVersion string `json:"latest_version,omitempty"`
|
||||
UpdateAvailable bool `json:"update_available"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// HandleCheckUpdates pings gitea. If air-gapped, it returns manual instructions.
|
||||
func (h *Handler) HandleCheckUpdates(w http.ResponseWriter, r *http.Request) {
|
||||
respPayload := UpdateCheckResponse{
|
||||
CurrentVersion: CurrentAppVersion,
|
||||
}
|
||||
|
||||
client := http.Client{Timeout: 3 * time.Second}
|
||||
|
||||
giteaURL := "https://epigas.gitea.cloud/api/v1/repos/RiskRancher/core/releases/latest"
|
||||
resp, err := client.Get(giteaURL)
|
||||
|
||||
if err != nil || resp.StatusCode != http.StatusOK {
|
||||
respPayload.Status = "offline"
|
||||
respPayload.Message = "No internet connection detected. To update an air-gapped server: Download the latest RiskRancher binary on a connected machine, transfer it via rsync or scp to this server, and restart the service."
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(respPayload)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var ghRelease struct {
|
||||
TagName string `json:"tag_name"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&ghRelease); err == nil {
|
||||
respPayload.Status = "online"
|
||||
respPayload.LatestVersion = ghRelease.TagName
|
||||
respPayload.UpdateAvailable = (ghRelease.TagName != CurrentAppVersion)
|
||||
|
||||
if respPayload.UpdateAvailable {
|
||||
respPayload.Message = "A new version is available! Please trigger a graceful shutdown and swap the binary."
|
||||
} else {
|
||||
respPayload.Message = "You are running the latest version."
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(respPayload)
|
||||
}
|
||||
|
||||
// HandleShutdown signals the application to close connections and exit cleanly
|
||||
func (h *Handler) HandleShutdown(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message": "Initiating graceful shutdown. The server will exit in 2 seconds..."}`))
|
||||
go func() {
|
||||
time.Sleep(2 * time.Second)
|
||||
}()
|
||||
}
|
||||
64
pkg/admin/admin_test.go
Normal file
64
pkg/admin/admin_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func TestGetGlobalConfig(t *testing.T) {
|
||||
app, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/config", nil)
|
||||
req.AddCookie(GetVIPCookie(app.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
app.HandleGetConfig(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var config domain.AppConfig
|
||||
if err := json.NewDecoder(rr.Body).Decode(&config); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if config.Timezone != "America/New_York" || config.BusinessStart != 9 {
|
||||
t.Errorf("Expected default config, got TZ: %s, Start: %d", config.Timezone, config.BusinessStart)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleDeactivateUser(t *testing.T) {
|
||||
h, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
targetUser, _ := h.Store.CreateUser(context.Background(), "fired@ranch.com", "Fired Fred", "hash", "RangeHand")
|
||||
res, _ := db.Exec(`INSERT INTO tickets (title, status, severity, source, dedupe_hash) VALUES ('Freds Task', 'Waiting to be Triaged', 'High', 'Manual', 'fake-hash-123')`)
|
||||
ticketID, _ := res.LastInsertId()
|
||||
db.Exec(`INSERT INTO ticket_assignments (ticket_id, assignee, role) VALUES (?, 'fired@ranch.com', 'RangeHand')`, ticketID)
|
||||
|
||||
targetURL := fmt.Sprintf("/api/admin/users/%d", targetUser.ID)
|
||||
req := httptest.NewRequest(http.MethodDelete, targetURL, nil)
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
req.SetPathValue("id", fmt.Sprintf("%d", targetUser.ID))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleDeactivateUser(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var count int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM ticket_assignments WHERE assignee = 'fired@ranch.com'`).Scan(&count)
|
||||
if count != 0 {
|
||||
t.Errorf("Expected assignments to be cleared, but found %d", count)
|
||||
}
|
||||
}
|
||||
106
pkg/admin/admin_users_test.go
Normal file
106
pkg/admin/admin_users_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandleAdminResetPassword(t *testing.T) {
|
||||
a, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
targetUser, _ := a.Store.CreateUser(context.Background(), "forgetful@ranch.com", "Forgetful Fred", "old_hash", "RangeHand")
|
||||
|
||||
payload := map[string]string{
|
||||
"new_password": "BrandNewSecurePassword123!",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
targetURL := fmt.Sprintf("/api/admin/users/%d/reset-password", targetUser.ID)
|
||||
req := httptest.NewRequest(http.MethodPatch, targetURL, bytes.NewBuffer(body))
|
||||
|
||||
req.SetPathValue("id", fmt.Sprintf("%d", targetUser.ID))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleAdminResetPassword(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUpdateUserRole(t *testing.T) {
|
||||
a, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
_, _ = a.Store.CreateUser(context.Background(), "boss@ranch.com", "The Boss", "hash", "Sheriff")
|
||||
targetUser, _ := a.Store.CreateUser(context.Background(), "rookie@ranch.com", "Rookie Ray", "hash", "RangeHand")
|
||||
|
||||
payload := map[string]string{
|
||||
"global_role": "Wrangler",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
targetURL := fmt.Sprintf("/api/admin/users/%d/role", targetUser.ID)
|
||||
req := httptest.NewRequest(http.MethodPatch, targetURL, bytes.NewBuffer(body))
|
||||
|
||||
req.AddCookie(GetVIPCookie(a.Store))
|
||||
req.SetPathValue("id", fmt.Sprintf("%d", targetUser.ID))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleUpdateUserRole(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleCreateUser_SheriffInvite(t *testing.T) {
|
||||
a, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
payload := map[string]string{
|
||||
"email": "magistrate@ranch.com",
|
||||
"full_name": "Mighty Magistrate",
|
||||
"password": "TempPassword123!",
|
||||
"global_role": "Magistrate",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/admin/users", bytes.NewBuffer(body))
|
||||
|
||||
req.AddCookie(GetVIPCookie(a.Store))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleCreateUser(rr, req)
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var count int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM users WHERE email = 'magistrate@ranch.com'`).Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("Expected user to be created in the database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleGetUsers(t *testing.T) {
|
||||
a, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/admin/users", nil)
|
||||
|
||||
req.AddCookie(GetVIPCookie(a.Store))
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleGetUsers(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
44
pkg/admin/export_test.go
Normal file
44
pkg/admin/export_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func TestExportSystemState(t *testing.T) {
|
||||
app, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO tickets (title, severity, status, dedupe_hash)
|
||||
VALUES ('Export Test Vuln', 'High', 'Triaged', 'test_hash_123')
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test ticket: %v", err)
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/admin/export", nil)
|
||||
req.AddCookie(GetVIPCookie(app.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
app.HandleExportState(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d", rr.Code)
|
||||
}
|
||||
|
||||
if rr.Header().Get("Content-Disposition") != "attachment; filename=RiskRancher_export.json" {
|
||||
t.Errorf("Missing or incorrect Content-Disposition header")
|
||||
}
|
||||
|
||||
var state domain.ExportState
|
||||
if err := json.NewDecoder(rr.Body).Decode(&state); err != nil {
|
||||
t.Fatalf("Failed to decode exported JSON: %v", err)
|
||||
}
|
||||
|
||||
if len(state.Tickets) == 0 || state.Tickets[0].Title != "Export Test Vuln" {
|
||||
t.Errorf("Export did not contain the expected ticket data")
|
||||
}
|
||||
}
|
||||
15
pkg/admin/handler.go
Normal file
15
pkg/admin/handler.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
// Handler encapsulates all Admin and Sheriff HTTP logic
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
// NewHandler creates a new Admin Handler
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
30
pkg/admin/helpers_test.go
Normal file
30
pkg/admin/helpers_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
// setupTestAdmin returns the clean Admin Handler and the raw DB
|
||||
func setupTestAdmin(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
// GetVIPCookie creates a dummy Sheriff user to bypass the Bouncer in tests
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
user, err := store.GetUserByEmail(context.Background(), "vip_test@RiskRancher.com")
|
||||
if err != nil {
|
||||
user, _ = store.CreateUser(context.Background(), "vip_test@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
}
|
||||
token := "vip_test_token_999"
|
||||
store.CreateSession(context.Background(), token, user.ID, time.Now().Add(1*time.Hour))
|
||||
return &http.Cookie{Name: "session_token", Value: token}
|
||||
}
|
||||
36
pkg/admin/updates_test.go
Normal file
36
pkg/admin/updates_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckUpdates_OfflineFallback(t *testing.T) {
|
||||
|
||||
app, db := setupTestAdmin(t)
|
||||
defer db.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/admin/check-updates", nil)
|
||||
req.AddCookie(GetVIPCookie(app.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
app.HandleCheckUpdates(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(rr.Body).Decode(&response); err != nil {
|
||||
t.Fatalf("Failed to decode response: %v", err)
|
||||
}
|
||||
|
||||
if _, exists := response["status"]; !exists {
|
||||
t.Errorf("Expected 'status' field in response")
|
||||
}
|
||||
if _, exists := response["message"]; !exists {
|
||||
t.Errorf("Expected 'message' field in response")
|
||||
}
|
||||
}
|
||||
17
pkg/analytics/analytics.go
Normal file
17
pkg/analytics/analytics.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package analytics
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func (h *Handler) HandleGetAnalyticsSummary(w http.ResponseWriter, r *http.Request) {
|
||||
summary, err := h.Store.GetAnalyticsSummary(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate analytics", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(summary)
|
||||
}
|
||||
60
pkg/analytics/analytics_test.go
Normal file
60
pkg/analytics/analytics_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package analytics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func setupTestAnalytics(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
user, _ := store.CreateUser(context.Background(), "vip@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
store.CreateSession(context.Background(), "vip_token_999", user.ID, time.Now().Add(1*time.Hour))
|
||||
return &http.Cookie{Name: "session_token", Value: "vip_token_999"}
|
||||
}
|
||||
|
||||
func TestAnalyticsSummary(t *testing.T) {
|
||||
h, db := setupTestAnalytics(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec(`INSERT INTO tickets (source, title, severity, status, dedupe_hash) VALUES
|
||||
('Trivy', 'Container CVE', 'Critical', 'Waiting to be Triaged', 'hash1'),
|
||||
('Trivy', 'Old Lib', 'High', 'Waiting to be Triaged', 'hash2'),
|
||||
('Trivy', 'Patched Lib', 'Critical', 'Patched', 'hash3'),
|
||||
('Manual Pentest', 'SQLi', 'Critical', 'Waiting to be Triaged', 'hash4')
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert dummy data: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/analytics/summary", nil)
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.HandleGetAnalyticsSummary(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var summary map[string]int
|
||||
if err := json.NewDecoder(rr.Body).Decode(&summary); err != nil {
|
||||
t.Fatalf("Failed to decode JSON: %v", err)
|
||||
}
|
||||
|
||||
if summary["Total_Open"] != 3 {
|
||||
t.Errorf("Expected 3 total open tickets, got %d", summary["Total_Open"])
|
||||
}
|
||||
}
|
||||
13
pkg/analytics/handler.go
Normal file
13
pkg/analytics/handler.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package analytics
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
41
pkg/auth/auth.go
Normal file
41
pkg/auth/auth.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"math/rand"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Handler encapsulates all Identity and Access HTTP logic
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
// NewHandler creates a new Auth Handler
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
|
||||
// HashPassword takes a plaintext password, automatically generates a secure salt
|
||||
func HashPassword(password string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CheckPasswordHash securely compares a plaintext password with a stored bcrypt hash.
|
||||
func CheckPasswordHash(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// GenerateSessionToken creates a cryptographically secure random string
|
||||
func GenerateSessionToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
140
pkg/auth/auth_handlers.go
Normal file
140
pkg/auth/auth_handlers.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const SessionCookieName = "session_token"
|
||||
|
||||
// RegisterRequest represents the JSON payload expected for user registration.
|
||||
type RegisterRequest struct {
|
||||
Email string `json:"email"`
|
||||
FullName string `json:"full_name"`
|
||||
Password string `json:"password"`
|
||||
GlobalRole string `json:"global_role"`
|
||||
}
|
||||
|
||||
// LoginRequest represents the JSON payload expected for user login.
|
||||
type LoginRequest struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// HandleRegister processes new user signups.
|
||||
func (h *Handler) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
var req RegisterRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
count, err := h.Store.GetUserCount(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
http.Error(w, "Forbidden: System already initialized. Contact your Sheriff for an account.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
req.GlobalRole = "Sheriff"
|
||||
|
||||
if req.Email == "" || req.Password == "" || req.FullName == "" {
|
||||
http.Error(w, "Missing required fields", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
hashedPassword, err := HashPassword(req.Password)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to hash password", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.CreateUser(r.Context(), req.Email, req.FullName, hashedPassword, req.GlobalRole)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
|
||||
http.Error(w, "Email already exists", http.StatusConflict)
|
||||
return
|
||||
}
|
||||
http.Error(w, "Failed to create user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// HandleLogin authenticates a user and issues a session cookie.
|
||||
func (h *Handler) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var req LoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByEmail(r.Context(), req.Email)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if !CheckPasswordHash(req.Password, user.PasswordHash) {
|
||||
http.Error(w, "Invalid credentials", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := GenerateSessionToken()
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to generate session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
if err := h.Store.CreateSession(r.Context(), token, user.ID, expiresAt); err != nil {
|
||||
http.Error(w, "Failed to persist session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: token,
|
||||
Expires: expiresAt,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to TRUE in production for HTTPS!
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(user)
|
||||
}
|
||||
|
||||
// HandleLogout destroys the user's session in the database and clears their cookie.
|
||||
func (h *Handler) HandleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie(SessionCookieName)
|
||||
|
||||
if err == nil && cookie.Value != "" {
|
||||
_ = h.Store.DeleteSession(r.Context(), cookie.Value)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
Expires: time.Unix(0, 0),
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true, // Ensures it's only sent over HTTPS
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Successfully logged out",
|
||||
})
|
||||
}
|
||||
111
pkg/auth/auth_handlers_test.go
Normal file
111
pkg/auth/auth_handlers_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
)
|
||||
|
||||
func setupTestAuth(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
|
||||
h := NewHandler(store)
|
||||
|
||||
return h, db
|
||||
}
|
||||
|
||||
func TestAuthHandlers(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
t.Run("Successful Registration", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"full_name": "Doc Holliday",
|
||||
"password": "SuperSecretPassword123!",
|
||||
"global_role": "Sheriff", // Use a valid role!
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/register", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleRegister(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created for registration, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Successful Login Issues Cookie", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"password": "SuperSecretPassword123!",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK for successful login, got %d", rr.Code)
|
||||
}
|
||||
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) == 0 {
|
||||
t.Fatalf("Expected a session cookie to be set, but none was found")
|
||||
}
|
||||
if cookies[0].Name != "session_token" {
|
||||
t.Errorf("Expected cookie named 'session_token', got '%s'", cookies[0].Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Failed Login Rejects Access", func(t *testing.T) {
|
||||
payload := map[string]string{
|
||||
"email": "admin@RiskRancher.com",
|
||||
"password": "WrongPassword!",
|
||||
}
|
||||
body, _ := json.Marshal(payload)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleLogin(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("Expected 401 Unauthorized for wrong password, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleLogout(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/auth/logout", nil)
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: SessionCookieName,
|
||||
Value: "fake-session-token-123",
|
||||
}
|
||||
req.AddCookie(cookie)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleLogout(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
49
pkg/auth/auth_test.go
Normal file
49
pkg/auth/auth_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordHashing(t *testing.T) {
|
||||
password := "SuperSecretSOCPassword123!"
|
||||
|
||||
hash, err := HashPassword(password)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
if hash == password {
|
||||
t.Fatalf("Security failure: Hash matches plain text!")
|
||||
}
|
||||
if len(hash) == 0 {
|
||||
t.Fatalf("Hash is empty")
|
||||
}
|
||||
|
||||
isValid := CheckPasswordHash(password, hash)
|
||||
if !isValid {
|
||||
t.Errorf("Expected valid password to match hash, but it failed")
|
||||
}
|
||||
|
||||
isInvalid := CheckPasswordHash("WrongPassword!", hash)
|
||||
if isInvalid {
|
||||
t.Errorf("Security failure: Incorrect password returned true!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionToken(t *testing.T) {
|
||||
|
||||
token1, err1 := GenerateSessionToken()
|
||||
token2, err2 := GenerateSessionToken()
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
t.Fatalf("Failed to generate session tokens")
|
||||
}
|
||||
|
||||
if len(token1) < 32 {
|
||||
t.Errorf("Token is too short for security standards: %d chars", len(token1))
|
||||
}
|
||||
|
||||
if token1 == token2 {
|
||||
t.Errorf("CRITICAL: RNG generated the exact same token twice: %s", token1)
|
||||
}
|
||||
}
|
||||
56
pkg/auth/middleware.go
Normal file
56
pkg/auth/middleware.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const UserIDKey contextKey = "user_id"
|
||||
|
||||
func (h *Handler) RequireAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized: Missing session cookie", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.Store.GetSession(r.Context(), cookie.Value)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized: Invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
http.Error(w, "Unauthorized: Session expired", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, session.UserID)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// RequireUIAuth checks for a valid session and redirects to /login if it fails,
|
||||
func (h *Handler) RequireUIAuth(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.Store.GetSession(r.Context(), cookie.Value)
|
||||
if err != nil || session.ExpiresAt.Before(time.Now()) {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), UserIDKey, session.UserID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
61
pkg/auth/middleware_test.go
Normal file
61
pkg/auth/middleware_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRequireAuthMiddleware(t *testing.T) {
|
||||
h, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
user, err := h.Store.CreateUser(context.Background(), "vip@RiskRancher.com", "Wyatt Earp", "fake_hash", "Sheriff")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seed test user: %v", err)
|
||||
}
|
||||
|
||||
validToken := "valid_test_token_123"
|
||||
expiresAt := time.Now().Add(1 * time.Hour)
|
||||
err = h.Store.CreateSession(context.Background(), validToken, user.ID, expiresAt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seed test session: %v", err)
|
||||
}
|
||||
|
||||
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Welcome to the VIP room"))
|
||||
})
|
||||
protectedHandler := h.RequireAuth(dummyHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cookieName string
|
||||
cookieValue string
|
||||
expectedStatus int
|
||||
}{
|
||||
{"Missing Cookie", "", "", http.StatusUnauthorized},
|
||||
{"Wrong Cookie Name", "wrong_name", validToken, http.StatusUnauthorized},
|
||||
{"Invalid Token", "session_token", "fake_invalid_token", http.StatusUnauthorized},
|
||||
{"Valid Token", "session_token", validToken, http.StatusOK},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
|
||||
if tt.cookieName != "" {
|
||||
req.AddCookie(&http.Cookie{Name: tt.cookieName, Value: tt.cookieValue})
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
protectedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
74
pkg/auth/rbac_middleware.go
Normal file
74
pkg/auth/rbac_middleware.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// RequireRole acts as the checker
|
||||
func (h *Handler) RequireRole(requiredRole string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userIDVal := r.Context().Value(UserIDKey)
|
||||
if userIDVal == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := userIDVal.(int)
|
||||
if !ok {
|
||||
http.Error(w, "Internal Server Error: Invalid user context", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden: User not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if user.GlobalRole != requiredRole {
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyRole allows access if the user has ANY of the provided roles.
|
||||
func (h *Handler) RequireAnyRole(allowedRoles ...string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userIDVal := r.Context().Value(UserIDKey)
|
||||
if userIDVal == nil {
|
||||
http.Error(w, "Unauthorized: No user context", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := userIDVal.(int)
|
||||
if !ok {
|
||||
http.Error(w, "Internal Server Error: Invalid user context", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByID(r.Context(), userID)
|
||||
if err != nil {
|
||||
http.Error(w, "Forbidden: User not found", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
for _, role := range allowedRoles {
|
||||
if user.GlobalRole == role {
|
||||
// Match found! Open the door.
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
http.Error(w, "Forbidden: Insufficient permissions", http.StatusForbidden)
|
||||
})
|
||||
}
|
||||
}
|
||||
49
pkg/auth/rbac_middleware_test.go
Normal file
49
pkg/auth/rbac_middleware_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequireRoleMiddleware(t *testing.T) {
|
||||
a, db := setupTestAuth(t)
|
||||
defer db.Close()
|
||||
|
||||
sheriff, _ := a.Store.CreateUser(context.Background(), "sheriff@ranch.com", "Wyatt Earp", "hash", "Sheriff")
|
||||
rangeHand, _ := a.Store.CreateUser(context.Background(), "hand@ranch.com", "Jesse James", "hash", "RangeHand")
|
||||
|
||||
vipHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("Welcome to the Manager's Office"))
|
||||
})
|
||||
|
||||
protectedHandler := a.RequireRole("Sheriff")(vipHandler)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int
|
||||
expectedStatus int
|
||||
}{
|
||||
{"Valid Sheriff Access", sheriff.ID, http.StatusOK},
|
||||
{"Denied RangeHand Access", rangeHand.ID, http.StatusForbidden},
|
||||
{"Unknown User", 9999, http.StatusForbidden},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/admin/passwords", nil)
|
||||
|
||||
ctx := context.WithValue(req.Context(), UserIDKey, tt.userID)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
protectedHandler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
187
pkg/datastore/auth_db.go
Normal file
187
pkg/datastore/auth_db.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
// ErrNotFound is a standard error we can use across our handlers
|
||||
var ErrNotFound = errors.New("record not found")
|
||||
|
||||
func (s *SQLiteStore) CreateUser(ctx context.Context, email, fullName, passwordHash, globalRole string) (*domain.User, error) {
|
||||
query := `INSERT INTO users (email, full_name, password_hash, global_role) VALUES (?, ?, ?, ?)`
|
||||
|
||||
result, err := s.DB.ExecContext(ctx, query, email, fullName, passwordHash, globalRole)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &domain.User{
|
||||
ID: int(id),
|
||||
Email: email,
|
||||
FullName: fullName,
|
||||
PasswordHash: passwordHash,
|
||||
GlobalRole: globalRole,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetUserByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
query := "SELECT id, email, password_hash, full_name, global_role FROM users WHERE email = ? AND is_active = 1"
|
||||
|
||||
err := s.DB.QueryRowContext(ctx, query, email).Scan(
|
||||
&user.ID,
|
||||
&user.Email,
|
||||
&user.PasswordHash,
|
||||
&user.FullName,
|
||||
&user.GlobalRole,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, sql.ErrNoRows // Bouncer says no (either wrong email, or deactivated)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CreateSession(ctx context.Context, token string, userID int, expiresAt time.Time) error {
|
||||
query := `INSERT INTO sessions (session_token, user_id, expires_at) VALUES (?, ?, ?)`
|
||||
_, err := s.DB.ExecContext(ctx, query, token, userID, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetSession(ctx context.Context, token string) (*domain.Session, error) {
|
||||
query := `SELECT session_token, user_id, expires_at FROM sessions WHERE session_token = ?`
|
||||
|
||||
var session domain.Session
|
||||
err := s.DB.QueryRowContext(ctx, query, token).Scan(
|
||||
&session.Token,
|
||||
&session.UserID,
|
||||
&session.ExpiresAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetUserByID fetches a user's full record, including their role
|
||||
func (s *SQLiteStore) GetUserByID(ctx context.Context, id int) (*domain.User, error) {
|
||||
query := `SELECT id, email, full_name, password_hash, global_role FROM users WHERE id = ?`
|
||||
|
||||
var user domain.User
|
||||
err := s.DB.QueryRowContext(ctx, query, id).Scan(
|
||||
&user.ID,
|
||||
&user.Email,
|
||||
&user.FullName,
|
||||
&user.PasswordHash,
|
||||
&user.GlobalRole,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateUserPassword allows an administrator to overwrite a forgotten password
|
||||
func (s *SQLiteStore) UpdateUserPassword(ctx context.Context, id int, newPasswordHash string) error {
|
||||
query := `UPDATE users SET password_hash = ? WHERE id = ?`
|
||||
|
||||
_, err := s.DB.ExecContext(ctx, query, newPasswordHash, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateUserRole promotes or demotes a user by updating their global_role.
|
||||
func (s *SQLiteStore) UpdateUserRole(ctx context.Context, id int, newRole string) error {
|
||||
query := `UPDATE users SET global_role = ? WHERE id = ?`
|
||||
|
||||
_, err := s.DB.ExecContext(ctx, query, newRole, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeactivateUserAndReassign securely offboards a user, kicks them out
|
||||
func (s *SQLiteStore) DeactivateUserAndReassign(ctx context.Context, userID int) error {
|
||||
var email string
|
||||
if err := s.DB.QueryRowContext(ctx, "SELECT email FROM users WHERE id = ?", userID).Scan(&email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.ExecContext(ctx, `UPDATE users SET is_active = 0 WHERE id = ?`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `DELETE FROM ticket_assignments WHERE assignee = ?`, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `DELETE FROM sessions WHERE user_id = ?`, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// GetUserCount returns the total number of registered users in the system.
|
||||
func (s *SQLiteStore) GetUserCount(ctx context.Context) (int, error) {
|
||||
var count int
|
||||
err := s.DB.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAllUsers(ctx context.Context) ([]*domain.User, error) {
|
||||
// Notice the return type is now []*domain.User
|
||||
rows, err := s.DB.QueryContext(ctx, "SELECT id, email, full_name, global_role FROM users WHERE is_active = 1")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*domain.User
|
||||
for rows.Next() {
|
||||
var u domain.User
|
||||
if err := rows.Scan(&u.ID, &u.Email, &u.FullName, &u.GlobalRole); err == nil {
|
||||
users = append(users, &u) // 🚀 Appending the memory address!
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// DeleteSession removes the token from the database so it can never be used again.
|
||||
func (s *SQLiteStore) DeleteSession(ctx context.Context, token string) error {
|
||||
_, err := s.DB.ExecContext(ctx, `DELETE FROM sessions WHERE token = ?`, token)
|
||||
return err
|
||||
}
|
||||
73
pkg/datastore/auth_db_test.go
Normal file
73
pkg/datastore/auth_db_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUserAndSessionLifecycle(t *testing.T) {
|
||||
store := setupTestDB(t)
|
||||
defer store.DB.Close()
|
||||
|
||||
_, err := store.DB.Exec(`
|
||||
CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT UNIQUE, full_name TEXT, password_hash TEXT, global_role TEXT, is_active BOOLEAN DEFAULT 1);
|
||||
CREATE TABLE sessions (session_token TEXT PRIMARY KEY, user_id INTEGER, expires_at DATETIME);
|
||||
`)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
user, err := store.CreateUser(ctx, "admin@RiskRancher.com", "doc", "fake_bcrypt_hash", "Admin")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user: %v", err)
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Errorf("Expected database to return a valid auto-incremented ID, got 0")
|
||||
}
|
||||
|
||||
_, err = store.CreateUser(ctx, "admin@RiskRancher.com", "doc", "another_hash", "Analyst")
|
||||
if err == nil {
|
||||
t.Fatalf("Security Failure: Database allowed a duplicate email address!")
|
||||
}
|
||||
|
||||
fetchedUser, err := store.GetUserByEmail(ctx, "admin@RiskRancher.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch user by email: %v", err)
|
||||
}
|
||||
if fetchedUser.GlobalRole != "Admin" {
|
||||
t.Errorf("Expected role 'Admin', got '%s'", fetchedUser.GlobalRole)
|
||||
}
|
||||
|
||||
expires := time.Now().Add(24 * time.Hour)
|
||||
err = store.CreateSession(ctx, "fake_secure_token", user.ID, expires)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session: %v", err)
|
||||
}
|
||||
|
||||
session, err := store.GetSession(ctx, "fake_secure_token")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve session: %v", err)
|
||||
}
|
||||
if session.UserID != user.ID {
|
||||
t.Errorf("Session mapped to wrong user! Expected %d, got %d", user.ID, session.UserID)
|
||||
}
|
||||
|
||||
userByID, err := store.GetUserByID(ctx, user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to fetch user by ID: %v", err)
|
||||
}
|
||||
if userByID.Email != user.Email {
|
||||
t.Errorf("GetUserByID returned wrong user. Expected %s, got %s", user.Email, userByID.Email)
|
||||
}
|
||||
|
||||
newHash := "new_secure_bcrypt_hash_999"
|
||||
err = store.UpdateUserPassword(ctx, user.ID, newHash)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update user password: %v", err)
|
||||
}
|
||||
|
||||
updatedUser, _ := store.GetUserByID(ctx, user.ID)
|
||||
if updatedUser.PasswordHash != newHash {
|
||||
t.Errorf("Password hash did not update in the database")
|
||||
}
|
||||
}
|
||||
92
pkg/datastore/concurrency_test.go
Normal file
92
pkg/datastore/concurrency_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// runChaosEngine fires 100 concurrent workers at the provided database connection
|
||||
func runChaosEngine(db *sql.DB) int {
|
||||
db.Exec(`CREATE TABLE IF NOT EXISTS tickets (id INTEGER PRIMARY KEY AUTOINCREMENT, title TEXT, status TEXT)`)
|
||||
db.Exec(`INSERT INTO tickets (title, status) VALUES ('Seed', 'Open')`)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, 1000)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 20; i++ {
|
||||
tx, _ := db.Begin()
|
||||
for j := 0; j < 50; j++ {
|
||||
tx.Exec(`INSERT INTO tickets (title, status) VALUES ('Vuln', 'Open')`)
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for w := 0; w < 20; w++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 20; i++ {
|
||||
if _, err := db.Exec(`UPDATE tickets SET status = 'Patched' WHERE id = 1`); err != nil {
|
||||
errCh <- err
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for r := 0; r < 79; r++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 50; i++ {
|
||||
rows, err := db.Query(`SELECT COUNT(*) FROM tickets`)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
} else {
|
||||
rows.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
errorCount := 0
|
||||
for range errCh {
|
||||
errorCount++
|
||||
}
|
||||
return errorCount
|
||||
}
|
||||
|
||||
func TestSQLiteConcurrency_Tuned_Succeeds(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
dbPath := filepath.Join(tempDir, "tuned.db")
|
||||
|
||||
dsn := fmt.Sprintf("%s?_journal_mode=WAL&_synchronous=NORMAL&_busy_timeout=5000", dbPath)
|
||||
db, err := sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open tuned DB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(25)
|
||||
|
||||
errors := runChaosEngine(db)
|
||||
|
||||
if errors > 0 {
|
||||
t.Fatalf("FAILED! Tuned engine threw %d errors. It should have queued them perfectly.", errors)
|
||||
}
|
||||
t.Log("SUCCESS: 100 concurrent workers survived SQLite chaos with ZERO locked errors.")
|
||||
}
|
||||
94
pkg/datastore/db.go
Normal file
94
pkg/datastore/db.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
_ "embed"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var schemaSQL string
|
||||
|
||||
//go:embed defaults/*.json
|
||||
var defaultAdaptersFS embed.FS
|
||||
|
||||
func InitDB(filepath string) *sql.DB {
|
||||
dsn := "file:" + filepath + "?_journal=WAL&_timeout=5000&_sync=1&_fk=1"
|
||||
|
||||
db, err := sql.Open("sqlite3", dsn)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(25)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
|
||||
migrations := []string{
|
||||
schemaSQL,
|
||||
}
|
||||
|
||||
if err := RunMigrations(db, migrations); err != nil {
|
||||
log.Fatalf("Database upgrade failed! Halting boot to protect data: %v", err)
|
||||
}
|
||||
|
||||
SeedAdapters(db)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// SeedAdapters reads the embedded JSON files and UPSERTs them into SQLite
|
||||
func SeedAdapters(db *sql.DB) {
|
||||
files, err := defaultAdaptersFS.ReadDir("defaults")
|
||||
if err != nil {
|
||||
log.Printf("No default adapters found or failed to read: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
data, err := defaultAdaptersFS.ReadFile("defaults/" + file.Name())
|
||||
if err != nil {
|
||||
log.Printf("Failed to read adapter file %s: %v", file.Name(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
var adapter domain.Adapter
|
||||
if err := json.Unmarshal(data, &adapter); err != nil {
|
||||
log.Printf("Failed to parse adapter JSON %s: %v", file.Name(), err)
|
||||
continue
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO data_adapters (
|
||||
name, source_name, findings_path, mapping_title,
|
||||
mapping_asset, mapping_severity, mapping_description, mapping_remediation
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(name) DO UPDATE SET
|
||||
source_name = excluded.source_name,
|
||||
findings_path = excluded.findings_path,
|
||||
mapping_title = excluded.mapping_title,
|
||||
mapping_asset = excluded.mapping_asset,
|
||||
mapping_severity = excluded.mapping_severity,
|
||||
mapping_description = excluded.mapping_description,
|
||||
mapping_remediation = excluded.mapping_remediation,
|
||||
updated_at = CURRENT_TIMESTAMP;
|
||||
`
|
||||
|
||||
_, err = db.Exec(query,
|
||||
adapter.Name, adapter.SourceName, adapter.FindingsPath, adapter.MappingTitle,
|
||||
adapter.MappingAsset, adapter.MappingSeverity, adapter.MappingDescription, adapter.MappingRemediation,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to seed adapter %s to DB: %v", adapter.Name, err)
|
||||
} else {
|
||||
log.Printf("🔌 Successfully loaded adapter: %s", adapter.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
10
pkg/datastore/defaults/trivy.json
Normal file
10
pkg/datastore/defaults/trivy.json
Normal file
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"name": "Trivy Container Scan",
|
||||
"source_name": "Trivy",
|
||||
"findings_path": "Results.0.Vulnerabilities",
|
||||
"mapping_title": "VulnerabilityID",
|
||||
"mapping_asset": "PkgName",
|
||||
"mapping_severity": "Severity",
|
||||
"mapping_description": "Title",
|
||||
"mapping_remediation": "FixedVersion"
|
||||
}
|
||||
84
pkg/datastore/diff_test.go
Normal file
84
pkg/datastore/diff_test.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
_ "github.com/mattn/go-sqlite3" // We need the SQLite driver for the test
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *SQLiteStore {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open in-memory SQLite database: %v", err)
|
||||
}
|
||||
|
||||
store := &SQLiteStore{DB: db}
|
||||
return store
|
||||
}
|
||||
|
||||
func TestIngestionDiffEngine(t *testing.T) {
|
||||
store := setupTestDB(t)
|
||||
defer store.DB.Close()
|
||||
_, err := store.DB.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS sla_policies (domain TEXT, severity TEXT, days_to_remediate INTEGER, max_extensions INTEGER, days_to_triage INTEGER);
|
||||
CREATE TABLE IF NOT EXISTS routing_rules (id INTEGER, rule_type TEXT, match_value TEXT, assignee TEXT, role TEXT);
|
||||
CREATE TABLE IF NOT EXISTS ticket_assignments (ticket_id INTEGER, assignee TEXT, role TEXT);
|
||||
CREATE TABLE IF NOT EXISTS ticket_activity (ticket_id INTEGER, actor TEXT, activity_type TEXT, new_value TEXT);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tickets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT,
|
||||
asset_identifier TEXT,
|
||||
title TEXT,
|
||||
severity TEXT,
|
||||
description TEXT,
|
||||
status TEXT,
|
||||
dedupe_hash TEXT UNIQUE,
|
||||
patched_at DATETIME,
|
||||
domain TEXT,
|
||||
triage_due_date DATETIME,
|
||||
remediation_due_date DATETIME
|
||||
)`)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
store.DB.Exec(`INSERT INTO tickets (source, asset_identifier, title, severity, description, status, dedupe_hash) VALUES
|
||||
('Trivy', 'Server-A', 'Old Vuln', 'High', 'Desc', 'Waiting to be Triaged', 'hash_1_open')`)
|
||||
|
||||
store.DB.Exec(`INSERT INTO tickets (source, asset_identifier, title, severity, description, status, dedupe_hash) VALUES
|
||||
('Trivy', 'Server-A', 'Old Vuln', 'High', 'Desc', 'Waiting to be Triaged', 'hash_1_open')`)
|
||||
|
||||
store.DB.Exec(`INSERT INTO tickets (source, asset_identifier, title, severity, description, status, dedupe_hash) VALUES
|
||||
('Trivy', 'Server-A', 'Regressed Vuln', 'High', 'Desc', 'Patched', 'hash_2_patched')`)
|
||||
incomingPayload := []domain.Ticket{
|
||||
{Source: "Trivy", AssetIdentifier: "Server-A", Title: "Regressed Vuln", DedupeHash: "hash_2_patched"},
|
||||
{Source: "Trivy", AssetIdentifier: "Server-A", Title: "Brand New Vuln", DedupeHash: "hash_3_new"},
|
||||
}
|
||||
|
||||
err = store.ProcessIngestionBatch(context.Background(), "Trivy", "Server-A", incomingPayload)
|
||||
if err != nil {
|
||||
t.Fatalf("Diff Engine failed: %v", err)
|
||||
}
|
||||
|
||||
var status string
|
||||
|
||||
store.DB.QueryRow(`SELECT status FROM tickets WHERE dedupe_hash = 'hash_1_open'`).Scan(&status)
|
||||
if status != "Patched" {
|
||||
t.Errorf("Expected hash_1_open to be Auto-Patched, got %s", status)
|
||||
}
|
||||
|
||||
store.DB.QueryRow(`SELECT status FROM tickets WHERE dedupe_hash = 'hash_2_patched'`).Scan(&status)
|
||||
if status != "Waiting to be Triaged" {
|
||||
t.Errorf("Expected hash_2_patched to be Re-opened, got %s", status)
|
||||
}
|
||||
|
||||
store.DB.QueryRow(`SELECT status FROM tickets WHERE dedupe_hash = 'hash_3_new'`).Scan(&status)
|
||||
if status != "Waiting to be Triaged" {
|
||||
t.Errorf("Expected hash_3_new to be newly created, got %s", status)
|
||||
}
|
||||
}
|
||||
58
pkg/datastore/migrate.go
Normal file
58
pkg/datastore/migrate.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// RunMigrations ensures the database schema matches the binary version
|
||||
func RunMigrations(db *sql.DB, migrations []string) error {
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create schema_migrations table: %v", err)
|
||||
}
|
||||
|
||||
var currentVersion int
|
||||
err = db.QueryRow("SELECT IFNULL(MAX(version), 0) FROM schema_migrations").Scan(¤tVersion)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("failed to read current schema version: %v", err)
|
||||
}
|
||||
|
||||
for i, query := range migrations {
|
||||
migrationVersion := i + 1
|
||||
|
||||
if migrationVersion > currentVersion {
|
||||
log.Printf("🚀 Applying database migration v%d...", migrationVersion)
|
||||
|
||||
// Start a transaction so if the ALTER TABLE fails, it rolls back cleanly
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("migration v%d failed: %v", migrationVersion, err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", migrationVersion); err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("failed to record migration v%d: %v", migrationVersion, err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("✅ Migration v%d applied successfully.", migrationVersion)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
42
pkg/datastore/migrate_test.go
Normal file
42
pkg/datastore/migrate_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestSchemaMigrations(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open test db: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrations := []string{
|
||||
`CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT);`,
|
||||
`ALTER TABLE users ADD COLUMN email TEXT;`,
|
||||
}
|
||||
|
||||
err = RunMigrations(db, migrations)
|
||||
if err != nil {
|
||||
t.Fatalf("Initial migration failed: %v", err)
|
||||
}
|
||||
|
||||
var version int
|
||||
db.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
|
||||
if version != 2 {
|
||||
t.Errorf("Expected database to be at version 2, got %d", version)
|
||||
}
|
||||
|
||||
err = RunMigrations(db, migrations)
|
||||
if err != nil {
|
||||
t.Fatalf("Idempotent migration failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec("INSERT INTO users (name, email) VALUES ('Tim', 'tim@ranch.com')")
|
||||
if err != nil {
|
||||
t.Errorf("Migration 2 did not apply correctly! Column 'email' missing: %v", err)
|
||||
}
|
||||
}
|
||||
147
pkg/datastore/schema.sql
Normal file
147
pkg/datastore/schema.sql
Normal file
@@ -0,0 +1,147 @@
|
||||
CREATE TABLE IF NOT EXISTS app_config (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
timezone TEXT DEFAULT 'America/New_York',
|
||||
business_start INTEGER DEFAULT 9,
|
||||
business_end INTEGER DEFAULT 17,
|
||||
default_extension_days INTEGER DEFAULT 30,
|
||||
backup_enabled BOOLEAN DEFAULT 1,
|
||||
backup_interval_hours INTEGER DEFAULT 24,
|
||||
backup_retention_days INTEGER DEFAULT 30
|
||||
);
|
||||
|
||||
INSERT OR IGNORE INTO app_config (id) VALUES (1);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS domains (name TEXT PRIMARY KEY);
|
||||
INSERT OR IGNORE INTO domains (name) VALUES ('Vulnerability'), ('Privacy'), ('Compliance'), ('Incident');
|
||||
|
||||
CREATE TABLE IF NOT EXISTS departments (name TEXT PRIMARY KEY);
|
||||
INSERT OR IGNORE INTO departments (name) VALUES ('Security'), ('IT'), ('Privacy'), ('Legal'), ('Compliance');
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sla_policies (
|
||||
domain TEXT NOT NULL,
|
||||
severity TEXT NOT NULL,
|
||||
days_to_triage INTEGER NOT NULL DEFAULT 3,
|
||||
days_to_remediate INTEGER NOT NULL,
|
||||
max_extensions INTEGER NOT NULL DEFAULT 3,
|
||||
PRIMARY KEY (domain, severity),
|
||||
FOREIGN KEY(domain) REFERENCES domains(name) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
INSERT OR IGNORE INTO sla_policies (domain, severity, days_to_triage, days_to_remediate, max_extensions) VALUES
|
||||
('Vulnerability', 'Critical', 3, 14, 1), ('Vulnerability', 'High', 3, 30, 2),
|
||||
('Privacy', 'Critical', 3, 3, 0), ('Privacy', 'High', 3, 7, 1),
|
||||
('Incident', 'Critical', 3, 1, 0);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
full_name TEXT NOT NULL,
|
||||
global_role TEXT NOT NULL CHECK(global_role IN ('Sheriff', 'RangeHand', 'Wrangler', 'CircuitRider', 'Magistrate')),
|
||||
department TEXT NOT NULL DEFAULT 'Security',
|
||||
is_active BOOLEAN DEFAULT 1,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(department) REFERENCES departments(name) ON DELETE SET DEFAULT
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
session_token TEXT PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
expires_at DATETIME NOT NULL,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS tickets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
domain TEXT NOT NULL DEFAULT 'Vulnerability',
|
||||
source TEXT NOT NULL DEFAULT 'Manual',
|
||||
asset_identifier TEXT NOT NULL DEFAULT 'Default',
|
||||
cve_id TEXT,
|
||||
audit_id TEXT UNIQUE,
|
||||
compliance_tags TEXT,
|
||||
title TEXT NOT NULL,
|
||||
description TEXT,
|
||||
recommended_remediation TEXT,
|
||||
severity TEXT NOT NULL,
|
||||
status TEXT DEFAULT 'Waiting to be Triaged'
|
||||
CHECK(status IN (
|
||||
'Waiting to be Triaged',
|
||||
'Returned to Security',
|
||||
'Triaged',
|
||||
'Assigned Out',
|
||||
'Patched',
|
||||
'False Positive'
|
||||
)),
|
||||
dedupe_hash TEXT UNIQUE NOT NULL,
|
||||
patch_evidence TEXT,
|
||||
accessible_to_internet BOOLEAN DEFAULT 0,
|
||||
assignee TEXT DEFAULT 'Unassigned',
|
||||
latest_comment TEXT DEFAULT '',
|
||||
|
||||
assigned_at DATETIME,
|
||||
owner_viewed_at DATETIME,
|
||||
triage_due_date DATETIME,
|
||||
remediation_due_date DATETIME,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
patched_at DATETIME,
|
||||
FOREIGN KEY(domain) REFERENCES domains(name) ON DELETE SET DEFAULT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_status ON tickets(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_severity ON tickets(severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_domain ON tickets(domain);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_source_asset ON tickets(source, asset_identifier);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ticket_assignments (
|
||||
ticket_id INTEGER NOT NULL,
|
||||
assignee TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK(role IN ('RangeHand', 'Wrangler', 'Magistrate')),
|
||||
PRIMARY KEY (ticket_id, assignee, role),
|
||||
FOREIGN KEY(ticket_id) REFERENCES tickets(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS data_adapters (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
source_name TEXT NOT NULL,
|
||||
findings_path TEXT NOT NULL DEFAULT '.',
|
||||
mapping_title TEXT NOT NULL,
|
||||
mapping_asset TEXT NOT NULL,
|
||||
mapping_severity TEXT NOT NULL,
|
||||
mapping_description TEXT,
|
||||
mapping_remediation TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sync_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source TEXT NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
records_processed INTEGER NOT NULL,
|
||||
error_message TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS draft_tickets (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
report_id TEXT NOT NULL,
|
||||
title TEXT DEFAULT '',
|
||||
description TEXT,
|
||||
severity TEXT DEFAULT 'Medium',
|
||||
asset_identifier TEXT DEFAULT '',
|
||||
recommended_remediation TEXT DEFAULT '',
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_draft_tickets_report_id ON draft_tickets(report_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_assignments_assignee ON ticket_assignments(assignee);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_status_asset ON tickets(status, asset_identifier);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_updated_at ON tickets(updated_at);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_analytics ON tickets(status, severity, source);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_due_dates ON tickets(status, remediation_due_date, triage_due_date);
|
||||
CREATE INDEX IF NOT EXISTS idx_tickets_source_status ON tickets(source, status);
|
||||
17
pkg/datastore/sqlite.go
Normal file
17
pkg/datastore/sqlite.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type SQLiteStore struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
var _ domain.TicketStore = (*SQLiteStore)(nil)
|
||||
|
||||
func NewSQLiteStore(db *sql.DB) *SQLiteStore {
|
||||
return &SQLiteStore{DB: db}
|
||||
}
|
||||
173
pkg/datastore/sqlite_admin.go
Normal file
173
pkg/datastore/sqlite_admin.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (s *SQLiteStore) UpdateAppConfig(ctx context.Context, config domain2.AppConfig) error {
|
||||
query := `
|
||||
INSERT INTO app_config (id, timezone, business_start, business_end, default_extension_days)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
timezone = excluded.timezone,
|
||||
business_start = excluded.business_start,
|
||||
business_end = excluded.business_end,
|
||||
default_extension_days = excluded.default_extension_days
|
||||
`
|
||||
_, err := s.DB.ExecContext(ctx, query, config.Timezone, config.BusinessStart, config.BusinessEnd, config.DefaultExtensionDays)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAppConfig(ctx context.Context) (domain2.AppConfig, error) {
|
||||
var c domain2.AppConfig
|
||||
|
||||
query := `SELECT timezone, business_start, business_end, default_extension_days,
|
||||
backup_enabled, backup_interval_hours, backup_retention_days
|
||||
FROM app_config WHERE id = 1`
|
||||
|
||||
err := s.DB.QueryRowContext(ctx, query).Scan(
|
||||
&c.Timezone, &c.BusinessStart, &c.BusinessEnd, &c.DefaultExtensionDays,
|
||||
&c.Backup.Enabled, &c.Backup.IntervalHours, &c.Backup.RetentionDays,
|
||||
)
|
||||
return c, err
|
||||
}
|
||||
|
||||
// buildSLAMap creates a fast 2D lookup table: map[Domain][Severity]Policy
|
||||
func (s *SQLiteStore) buildSLAMap(ctx context.Context) (map[string]map[string]domain2.SLAPolicy, error) {
|
||||
policies, err := s.GetSLAPolicies(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slaMap := make(map[string]map[string]domain2.SLAPolicy)
|
||||
for _, p := range policies {
|
||||
if slaMap[p.Domain] == nil {
|
||||
slaMap[p.Domain] = make(map[string]domain2.SLAPolicy)
|
||||
}
|
||||
slaMap[p.Domain][p.Severity] = p
|
||||
}
|
||||
return slaMap, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) ExportSystemState(ctx context.Context) (domain2.ExportState, error) {
|
||||
var state domain2.ExportState
|
||||
state.Version = "1.1"
|
||||
state.ExportedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
config, err := s.GetAppConfig(ctx)
|
||||
if err == nil {
|
||||
state.AppConfig = config
|
||||
}
|
||||
|
||||
slas, err := s.GetSLAPolicies(ctx)
|
||||
if err == nil {
|
||||
state.SLAPolicies = slas
|
||||
}
|
||||
|
||||
users, err := s.GetAllUsers(ctx)
|
||||
if err == nil {
|
||||
for _, u := range users {
|
||||
u.PasswordHash = ""
|
||||
state.Users = append(state.Users, *u)
|
||||
}
|
||||
}
|
||||
|
||||
adapters, err := s.GetAdapters(ctx)
|
||||
if err == nil {
|
||||
state.Adapters = adapters
|
||||
}
|
||||
|
||||
query := `SELECT id, domain, source, asset_identifier, title, COALESCE(description, ''), severity, status, dedupe_hash, created_at FROM tickets`
|
||||
rows, err := s.DB.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return state, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var t domain2.Ticket
|
||||
if err := rows.Scan(&t.ID, &t.Domain, &t.Source, &t.AssetIdentifier, &t.Title, &t.Description, &t.Severity, &t.Status, &t.DedupeHash, &t.CreatedAt); err == nil {
|
||||
state.Tickets = append(state.Tickets, t)
|
||||
}
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) UpdateBackupPolicy(ctx context.Context, policy domain2.BackupPolicy) error {
|
||||
_, err := s.DB.ExecContext(ctx, `
|
||||
UPDATE app_config
|
||||
SET backup_enabled = ?, backup_interval_hours = ?, backup_retention_days = ?
|
||||
WHERE id = 1`,
|
||||
policy.Enabled, policy.IntervalHours, policy.RetentionDays)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetSLAPolicies(ctx context.Context) ([]domain2.SLAPolicy, error) {
|
||||
rows, err := s.DB.QueryContext(ctx, "SELECT domain, severity, days_to_remediate, max_extensions, days_to_triage FROM sla_policies ORDER BY domain, severity")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var policies []domain2.SLAPolicy
|
||||
for rows.Next() {
|
||||
var p domain2.SLAPolicy
|
||||
rows.Scan(&p.Domain, &p.Severity, &p.DaysToRemediate, &p.MaxExtensions, &p.DaysToTriage)
|
||||
policies = append(policies, p)
|
||||
}
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) UpdateSLAPolicies(ctx context.Context, slas []domain2.SLAPolicy) error {
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
UPDATE sla_policies
|
||||
SET days_to_triage = ?, days_to_remediate = ?, max_extensions = ?
|
||||
WHERE domain = ? AND severity = ?`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for _, sla := range slas {
|
||||
_, err = stmt.ExecContext(ctx, sla.DaysToTriage, sla.DaysToRemediate, sla.MaxExtensions, sla.Domain, sla.Severity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetWranglers(ctx context.Context) ([]domain2.User, error) {
|
||||
query := `
|
||||
SELECT id, email, full_name, global_role, is_active, created_at
|
||||
FROM users
|
||||
WHERE global_role = 'Wrangler' AND is_active = 1
|
||||
ORDER BY email ASC
|
||||
`
|
||||
rows, err := s.DB.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var wranglers []domain2.User
|
||||
for rows.Next() {
|
||||
var w domain2.User
|
||||
if err := rows.Scan(&w.ID, &w.Email, &w.FullName, &w.GlobalRole, &w.IsActive, &w.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wranglers = append(wranglers, w)
|
||||
}
|
||||
return wranglers, nil
|
||||
}
|
||||
357
pkg/datastore/sqlite_analytics.go
Normal file
357
pkg/datastore/sqlite_analytics.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (s *SQLiteStore) GetSheriffAnalytics(ctx context.Context) (domain2.SheriffAnalytics, error) {
|
||||
var metrics domain2.SheriffAnalytics
|
||||
|
||||
s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM tickets WHERE is_cisa_kev = 1 AND status NOT IN ('Patched', 'Risk Accepted', 'False Positive')").Scan(&metrics.ActiveKEVs)
|
||||
s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM tickets WHERE severity = 'Critical' AND status NOT IN ('Patched', 'Risk Accepted', 'False Positive')").Scan(&metrics.OpenCriticals)
|
||||
s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM tickets WHERE remediation_due_date < CURRENT_TIMESTAMP AND status NOT IN ('Patched', 'Risk Accepted', 'False Positive')").Scan(&metrics.TotalOverdue)
|
||||
|
||||
mttrQuery := `
|
||||
SELECT COALESCE(AVG(julianday(t.patched_at) - julianday(t.created_at)), 0)
|
||||
FROM tickets t
|
||||
WHERE t.status = 'Patched'
|
||||
`
|
||||
var mttrFloat float64
|
||||
s.DB.QueryRowContext(ctx, mttrQuery).Scan(&mttrFloat)
|
||||
metrics.GlobalMTTRDays = int(mttrFloat)
|
||||
|
||||
sourceQuery := `
|
||||
SELECT
|
||||
t.source,
|
||||
SUM(CASE WHEN t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive') THEN 1 ELSE 0 END) as total_open,
|
||||
SUM(CASE WHEN t.severity = 'Critical' AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive') THEN 1 ELSE 0 END) as criticals,
|
||||
SUM(CASE WHEN t.is_cisa_kev = 1 AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive') THEN 1 ELSE 0 END) as cisa_kevs,
|
||||
SUM(CASE WHEN t.status = 'Waiting to be Triaged' THEN 1 ELSE 0 END) as untriaged,
|
||||
SUM(CASE WHEN t.remediation_due_date < CURRENT_TIMESTAMP AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive') THEN 1 ELSE 0 END) as patch_overdue,
|
||||
SUM(CASE WHEN t.status = 'Pending Risk Approval' THEN 1 ELSE 0 END) as pending_risk,
|
||||
|
||||
SUM(CASE WHEN t.status IN ('Patched', 'Risk Accepted', 'False Positive') THEN 1 ELSE 0 END) as total_closed,
|
||||
SUM(CASE WHEN t.status = 'Patched' THEN 1 ELSE 0 END) as patched,
|
||||
SUM(CASE WHEN t.status = 'Risk Accepted' THEN 1 ELSE 0 END) as risk_accepted,
|
||||
SUM(CASE WHEN t.status = 'False Positive' THEN 1 ELSE 0 END) as false_positive
|
||||
FROM tickets t
|
||||
GROUP BY t.source
|
||||
ORDER BY criticals DESC, patch_overdue DESC
|
||||
`
|
||||
rows, err := s.DB.QueryContext(ctx, sourceQuery)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var sm domain2.SourceMetrics
|
||||
rows.Scan(&sm.Source, &sm.TotalOpen, &sm.Criticals, &sm.CisaKEVs, &sm.Untriaged, &sm.PatchOverdue, &sm.PendingRisk, &sm.TotalClosed, &sm.Patched, &sm.RiskAccepted, &sm.FalsePositive)
|
||||
|
||||
topAssigneeQ := `
|
||||
SELECT COALESCE(ta.assignee, 'Unassigned'), COUNT(t.id) as c
|
||||
FROM tickets t LEFT JOIN ticket_assignments ta ON t.id = ta.ticket_id
|
||||
WHERE t.source = ? AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive')
|
||||
GROUP BY ta.assignee ORDER BY c DESC LIMIT 1`
|
||||
|
||||
var assignee string
|
||||
var count int
|
||||
s.DB.QueryRowContext(ctx, topAssigneeQ, sm.Source).Scan(&assignee, &count)
|
||||
if count > 0 {
|
||||
sm.TopAssignee = fmt.Sprintf("%s (%d)", assignee, count)
|
||||
} else {
|
||||
sm.TopAssignee = "N/A"
|
||||
}
|
||||
|
||||
if sm.PatchOverdue > 0 {
|
||||
sm.StrategicNote = "🚨 SLA Breach (Escalate to IT Managers)"
|
||||
} else if sm.Untriaged > 0 {
|
||||
sm.StrategicNote = "⚠️ Triage Bottleneck (Check Analysts)"
|
||||
} else if sm.PendingRisk > 0 {
|
||||
sm.StrategicNote = "⚖️ Blocked by Exec Adjudication"
|
||||
} else if sm.Criticals > 0 {
|
||||
sm.StrategicNote = "🔥 High Risk (Monitor closely)"
|
||||
} else if sm.RiskAccepted > sm.Patched && sm.TotalClosed > 0 {
|
||||
sm.StrategicNote = "👀 High Risk Acceptance Rate (Audit Required)"
|
||||
} else if sm.FalsePositive > sm.Patched && sm.TotalClosed > 0 {
|
||||
sm.StrategicNote = "🔧 Noisy Source (Scanner needs tuning)"
|
||||
} else if sm.TotalClosed > 0 {
|
||||
sm.StrategicNote = "✅ Healthy Resolution Velocity"
|
||||
} else {
|
||||
sm.StrategicNote = "✅ Routine Processing"
|
||||
}
|
||||
|
||||
metrics.SourceHealth = append(metrics.SourceHealth, sm)
|
||||
}
|
||||
}
|
||||
|
||||
sevQuery := `SELECT severity, COUNT(id) FROM tickets WHERE status NOT IN ('Patched', 'Risk Accepted', 'False Positive') GROUP BY severity`
|
||||
rowsSev, err := s.DB.QueryContext(ctx, sevQuery)
|
||||
if err == nil {
|
||||
defer rowsSev.Close()
|
||||
for rowsSev.Next() {
|
||||
var sev string
|
||||
var count int
|
||||
rowsSev.Scan(&sev, &count)
|
||||
metrics.Severity.Total += count
|
||||
switch sev {
|
||||
case "Critical":
|
||||
metrics.Severity.Critical = count
|
||||
case "High":
|
||||
metrics.Severity.High = count
|
||||
case "Medium":
|
||||
metrics.Severity.Medium = count
|
||||
case "Low":
|
||||
metrics.Severity.Low = count
|
||||
case "Info":
|
||||
metrics.Severity.Info = count
|
||||
}
|
||||
}
|
||||
if metrics.Severity.Total > 0 {
|
||||
metrics.Severity.CritPct = int((float64(metrics.Severity.Critical) / float64(metrics.Severity.Total)) * 100)
|
||||
metrics.Severity.HighPct = int((float64(metrics.Severity.High) / float64(metrics.Severity.Total)) * 100)
|
||||
metrics.Severity.MedPct = int((float64(metrics.Severity.Medium) / float64(metrics.Severity.Total)) * 100)
|
||||
metrics.Severity.LowPct = int((float64(metrics.Severity.Low) / float64(metrics.Severity.Total)) * 100)
|
||||
metrics.Severity.InfoPct = int((float64(metrics.Severity.Info) / float64(metrics.Severity.Total)) * 100)
|
||||
}
|
||||
}
|
||||
|
||||
resQuery := `SELECT status, COUNT(id) FROM tickets WHERE status IN ('Patched', 'Risk Accepted', 'False Positive') GROUP BY status`
|
||||
rowsRes, err := s.DB.QueryContext(ctx, resQuery)
|
||||
if err == nil {
|
||||
defer rowsRes.Close()
|
||||
for rowsRes.Next() {
|
||||
var status string
|
||||
var count int
|
||||
rowsRes.Scan(&status, &count)
|
||||
metrics.Resolution.Total += count
|
||||
|
||||
switch status {
|
||||
case "Patched":
|
||||
metrics.Resolution.Patched = count
|
||||
case "Risk Accepted":
|
||||
metrics.Resolution.RiskAccepted = count
|
||||
case "False Positive":
|
||||
metrics.Resolution.FalsePositive = count
|
||||
}
|
||||
}
|
||||
|
||||
if metrics.Resolution.Total > 0 {
|
||||
metrics.Resolution.PatchedPct = int((float64(metrics.Resolution.Patched) / float64(metrics.Resolution.Total)) * 100)
|
||||
metrics.Resolution.RiskAccPct = int((float64(metrics.Resolution.RiskAccepted) / float64(metrics.Resolution.Total)) * 100)
|
||||
metrics.Resolution.FalsePosPct = int((float64(metrics.Resolution.FalsePositive) / float64(metrics.Resolution.Total)) * 100)
|
||||
}
|
||||
}
|
||||
|
||||
assetQuery := `SELECT asset_identifier, COUNT(id) as c FROM tickets WHERE status NOT IN ('Patched', 'Risk Accepted', 'False Positive') GROUP BY asset_identifier ORDER BY c DESC LIMIT 5`
|
||||
rowsAsset, err := s.DB.QueryContext(ctx, assetQuery)
|
||||
if err == nil {
|
||||
defer rowsAsset.Close()
|
||||
var maxAssetCount int
|
||||
for rowsAsset.Next() {
|
||||
var am domain2.AssetMetric
|
||||
rowsAsset.Scan(&am.Asset, &am.Count)
|
||||
if maxAssetCount == 0 {
|
||||
maxAssetCount = am.Count
|
||||
}
|
||||
if maxAssetCount > 0 {
|
||||
am.Percentage = int((float64(am.Count) / float64(maxAssetCount)) * 100)
|
||||
}
|
||||
metrics.TopAssets = append(metrics.TopAssets, am)
|
||||
}
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetDashboardTickets(ctx context.Context, tabStatus, filter, assetFilter, userEmail, userRole string, limit, offset int) ([]domain2.Ticket, int, map[string]int, error) {
|
||||
metrics := map[string]int{
|
||||
"critical": 0,
|
||||
"overdue": 0,
|
||||
"mine": 0,
|
||||
"verification": 0,
|
||||
"returned": 0,
|
||||
}
|
||||
|
||||
scope := ""
|
||||
var scopeArgs []any
|
||||
|
||||
if userRole == "Wrangler" {
|
||||
scope = ` AND LOWER(t.assignee) = LOWER(?)`
|
||||
scopeArgs = append(scopeArgs, userEmail)
|
||||
}
|
||||
|
||||
if userRole != "Sheriff" {
|
||||
var critCount, overCount, mineCount, verifyCount, returnedCount int
|
||||
|
||||
critQ := "SELECT COUNT(t.id) FROM tickets t WHERE t.severity = 'Critical' AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive')" + scope
|
||||
s.DB.QueryRowContext(ctx, critQ, scopeArgs...).Scan(&critCount)
|
||||
metrics["critical"] = critCount
|
||||
|
||||
overQ := "SELECT COUNT(t.id) FROM tickets t WHERE t.remediation_due_date < CURRENT_TIMESTAMP AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive')" + scope
|
||||
s.DB.QueryRowContext(ctx, overQ, scopeArgs...).Scan(&overCount)
|
||||
metrics["overdue"] = overCount
|
||||
|
||||
mineQ := "SELECT COUNT(t.id) FROM tickets t WHERE LOWER(t.assignee) = LOWER(?) AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive')"
|
||||
s.DB.QueryRowContext(ctx, mineQ, userEmail).Scan(&mineCount)
|
||||
metrics["mine"] = mineCount
|
||||
|
||||
verifyQ := "SELECT COUNT(t.id) FROM tickets t WHERE t.status = 'Pending Verification'" + scope
|
||||
s.DB.QueryRowContext(ctx, verifyQ, scopeArgs...).Scan(&verifyCount)
|
||||
metrics["verification"] = verifyCount
|
||||
|
||||
retQ := "SELECT COUNT(t.id) FROM tickets t WHERE t.status = 'Returned to Security'" + scope
|
||||
s.DB.QueryRowContext(ctx, retQ, scopeArgs...).Scan(&returnedCount)
|
||||
metrics["returned"] = returnedCount
|
||||
}
|
||||
|
||||
baseQ := "FROM tickets t WHERE 1=1" + scope
|
||||
var args []any
|
||||
args = append(args, scopeArgs...)
|
||||
|
||||
if assetFilter != "" {
|
||||
baseQ += " AND t.asset_identifier = ?"
|
||||
args = append(args, assetFilter)
|
||||
}
|
||||
|
||||
if tabStatus == "Waiting to be Triaged" || tabStatus == "holding_pen" {
|
||||
baseQ += " AND t.status IN ('Waiting to be Triaged', 'Returned to Security', 'Triaged')"
|
||||
} else if tabStatus == "Exceptions" {
|
||||
baseQ += " AND t.status NOT IN ('Patched', 'Risk Accepted', 'False Positive')"
|
||||
} else if tabStatus == "archives" {
|
||||
baseQ += " AND t.status IN ('Patched', 'Risk Accepted', 'False Positive')"
|
||||
} else if tabStatus != "" {
|
||||
baseQ += " AND t.status = ?"
|
||||
args = append(args, tabStatus)
|
||||
}
|
||||
|
||||
if filter == "critical" {
|
||||
baseQ += " AND t.severity = 'Critical'"
|
||||
} else if filter == "overdue" {
|
||||
baseQ += " AND t.remediation_due_date < CURRENT_TIMESTAMP"
|
||||
} else if filter == "mine" {
|
||||
baseQ += " AND LOWER(t.assignee) = LOWER(?)"
|
||||
args = append(args, userEmail)
|
||||
} else if tabStatus == "archives" && filter != "" && filter != "all" {
|
||||
baseQ += " AND t.status = ?"
|
||||
args = append(args, filter)
|
||||
}
|
||||
|
||||
var total int
|
||||
s.DB.QueryRowContext(ctx, "SELECT COUNT(t.id) "+baseQ, args...).Scan(&total)
|
||||
|
||||
orderClause := "ORDER BY (CASE WHEN t.status = 'Returned to Security' THEN 0 ELSE 1 END) ASC, t.id DESC"
|
||||
|
||||
query := `
|
||||
WITH PaginatedIDs AS (
|
||||
SELECT t.id ` + baseQ + ` ` + orderClause + ` LIMIT ? OFFSET ?
|
||||
)
|
||||
SELECT
|
||||
t.id, t.source, t.asset_identifier, t.title, COALESCE(t.description, ''), COALESCE(t.recommended_remediation, ''), t.severity, t.status,
|
||||
t.triage_due_date, t.remediation_due_date, COALESCE(t.patch_evidence, ''),
|
||||
t.assignee as current_assignee,
|
||||
t.owner_viewed_at,
|
||||
t.updated_at,
|
||||
CAST(julianday(COALESCE(t.patched_at, t.updated_at)) - julianday(t.created_at) AS INTEGER) as days_to_resolve,
|
||||
COALESCE(t.latest_comment, '') as latest_comment
|
||||
FROM PaginatedIDs p
|
||||
JOIN tickets t ON t.id = p.id
|
||||
` + orderClause
|
||||
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := s.DB.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, metrics, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tickets []domain2.Ticket
|
||||
for rows.Next() {
|
||||
var t domain2.Ticket
|
||||
var assignee string
|
||||
|
||||
err := rows.Scan(
|
||||
&t.ID, &t.Source, &t.AssetIdentifier, &t.Title, &t.Description,
|
||||
&t.RecommendedRemediation, &t.Severity, &t.Status,
|
||||
&t.TriageDueDate, &t.RemediationDueDate, &t.PatchEvidence,
|
||||
&assignee,
|
||||
&t.OwnerViewedAt,
|
||||
&t.UpdatedAt,
|
||||
&t.DaysToResolve,
|
||||
&t.LatestComment,
|
||||
)
|
||||
|
||||
if err == nil {
|
||||
t.Assignee = assignee
|
||||
t.IsOverdue = !t.RemediationDueDate.IsZero() && t.RemediationDueDate.Before(time.Now()) && t.Status != "Patched" && t.Status != "Risk Accepted"
|
||||
|
||||
if tabStatus == "archives" {
|
||||
if t.DaysToResolve != nil {
|
||||
t.SLAString = fmt.Sprintf("%d days", *t.DaysToResolve)
|
||||
} else {
|
||||
t.SLAString = "Unknown"
|
||||
}
|
||||
} else {
|
||||
t.SLAString = t.RemediationDueDate.Format("Jan 02, 2006")
|
||||
}
|
||||
|
||||
tickets = append(tickets, t)
|
||||
}
|
||||
}
|
||||
|
||||
return tickets, total, metrics, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetGlobalActivityFeed(ctx context.Context, limit int) ([]domain2.FeedItem, error) {
|
||||
return []domain2.FeedItem{
|
||||
{
|
||||
Actor: "System",
|
||||
ActivityType: "Info",
|
||||
NewValue: "Detailed Immutable Audit Logging is a RiskRancher Pro feature. Upgrade to track all ticket lifecycle events.",
|
||||
TimeAgo: "Just now",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAnalyticsSummary(ctx context.Context) (map[string]int, error) {
|
||||
summary := make(map[string]int)
|
||||
|
||||
var total int
|
||||
err := s.DB.QueryRowContext(ctx, `SELECT COUNT(*) FROM tickets WHERE status != 'Patched' AND status != 'Risk Accepted'`).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
summary["Total_Open"] = total
|
||||
|
||||
sourceRows, err := s.DB.QueryContext(ctx, `SELECT source, COUNT(*) FROM tickets WHERE status != 'Patched' AND status != 'Risk Accepted' GROUP BY source`)
|
||||
if err == nil {
|
||||
defer sourceRows.Close()
|
||||
for sourceRows.Next() {
|
||||
var source string
|
||||
var count int
|
||||
if err := sourceRows.Scan(&source, &count); err == nil {
|
||||
summary["Source_"+source+"_Open"] = count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sevRows, err := s.DB.QueryContext(ctx, `SELECT severity, COUNT(*) FROM tickets WHERE status != 'Patched' AND status != 'Risk Accepted' GROUP BY severity`)
|
||||
if err == nil {
|
||||
defer sevRows.Close()
|
||||
for sevRows.Next() {
|
||||
var sev string
|
||||
var count int
|
||||
if err := sevRows.Scan(&sev, &count); err == nil {
|
||||
summary["Severity_"+sev+"_Open"] = count
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetPaginatedActivityFeed(ctx context.Context, filter string, limit, offset int) ([]domain2.FeedItem, int, error) {
|
||||
return []domain2.FeedItem{}, 0, nil
|
||||
}
|
||||
109
pkg/datastore/sqlite_drafts.go
Normal file
109
pkg/datastore/sqlite_drafts.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (s *SQLiteStore) SaveDraft(ctx context.Context, d domain2.DraftTicket) error {
|
||||
query := `
|
||||
INSERT INTO draft_tickets (report_id, title, description, severity, asset_identifier, recommended_remediation)
|
||||
VALUES (?, ?, ?, ?, ?, ?)`
|
||||
|
||||
_, err := s.DB.ExecContext(ctx, query,
|
||||
d.ReportID, d.Title, d.Description, d.Severity, d.AssetIdentifier, d.RecommendedRemediation)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetDraftsByReport(ctx context.Context, reportID string) ([]domain2.DraftTicket, error) {
|
||||
|
||||
query := `SELECT id, report_id, COALESCE(title, ''), COALESCE(description, ''), COALESCE(severity, 'Medium'), COALESCE(asset_identifier, ''), COALESCE(recommended_remediation, '')
|
||||
FROM draft_tickets WHERE report_id = ?`
|
||||
|
||||
rows, err := s.DB.QueryContext(ctx, query, reportID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var drafts []domain2.DraftTicket
|
||||
for rows.Next() {
|
||||
var d domain2.DraftTicket
|
||||
if err := rows.Scan(&d.ID, &d.ReportID, &d.Title, &d.Description, &d.Severity, &d.AssetIdentifier, &d.RecommendedRemediation); err == nil {
|
||||
drafts = append(drafts, d)
|
||||
}
|
||||
}
|
||||
|
||||
if drafts == nil {
|
||||
drafts = []domain2.DraftTicket{}
|
||||
}
|
||||
return drafts, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteDraft(ctx context.Context, draftID string) error {
|
||||
query := `DELETE FROM draft_tickets WHERE id = ?`
|
||||
_, err := s.DB.ExecContext(ctx, query, draftID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) UpdateDraft(ctx context.Context, draftID int, payload domain2.Ticket) error {
|
||||
query := `UPDATE draft_tickets SET title = ?, severity = ?, asset_identifier = ?, description = ?, recommended_remediation = ? WHERE id = ?`
|
||||
|
||||
_, err := s.DB.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
payload.Title,
|
||||
payload.Severity,
|
||||
payload.AssetIdentifier,
|
||||
payload.Description,
|
||||
payload.RecommendedRemediation,
|
||||
draftID,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) PromotePentestDrafts(ctx context.Context, reportID string, analystEmail string, tickets []domain2.Ticket) error {
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, t := range tickets {
|
||||
hash := fmt.Sprintf("manual-pentest-%s-%s", t.AssetIdentifier, t.Title)
|
||||
|
||||
res, err := tx.ExecContext(ctx, `
|
||||
INSERT INTO tickets (
|
||||
source, asset_identifier, title, description, recommended_remediation, severity, status, dedupe_hash,
|
||||
triage_due_date, remediation_due_date, created_at, updated_at
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, 'Waiting to be Triaged', ?, DATETIME('now', '+3 days'), DATETIME('now', '+14 days'), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
`, "Manual Pentest", t.AssetIdentifier, t.Title, t.Description, t.RecommendedRemediation, t.Severity, hash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ticketID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
INSERT INTO ticket_assignments (ticket_id, assignee, role)
|
||||
VALUES (?, ?, 'RangeHand')
|
||||
`, ticketID, analystEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "DELETE FROM draft_tickets WHERE report_id = ?", reportID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
284
pkg/datastore/sqlite_ingest.go
Normal file
284
pkg/datastore/sqlite_ingest.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (s *SQLiteStore) IngestTickets(ctx context.Context, tickets []domain2.Ticket) error {
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
CREATE TEMP TABLE IF NOT EXISTS staging_tickets (
|
||||
domain TEXT, source TEXT, asset_identifier TEXT, title TEXT,
|
||||
description TEXT, recommended_remediation TEXT, severity TEXT,
|
||||
status TEXT, dedupe_hash TEXT
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.ExecContext(ctx, `DELETE FROM staging_tickets`)
|
||||
|
||||
stmt, err := tx.PrepareContext(ctx, `
|
||||
INSERT INTO staging_tickets (domain, source, asset_identifier, title, description, recommended_remediation, severity, status, dedupe_hash)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, t := range tickets {
|
||||
status := t.Status
|
||||
if status == "" {
|
||||
status = "Waiting to be Triaged"
|
||||
}
|
||||
domain := t.Domain
|
||||
if domain == "" {
|
||||
domain = "Vulnerability"
|
||||
}
|
||||
source := t.Source
|
||||
if source == "" {
|
||||
source = "Manual"
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, domain, source, t.AssetIdentifier, t.Title, t.Description, t.RecommendedRemediation, t.Severity, status, t.DedupeHash)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
return err
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
INSERT INTO tickets (domain, source, asset_identifier, title, description, recommended_remediation, severity, status, dedupe_hash)
|
||||
SELECT domain, source, asset_identifier, title, description, recommended_remediation, severity, status, dedupe_hash
|
||||
FROM staging_tickets
|
||||
WHERE true -- Prevents SQLite from mistaking 'ON CONFLICT' for a JOIN condition
|
||||
ON CONFLICT(dedupe_hash) DO UPDATE SET
|
||||
description = excluded.description,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx.ExecContext(ctx, `DROP TABLE staging_tickets`)
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAdapters(ctx context.Context) ([]domain2.Adapter, error) {
|
||||
rows, err := s.DB.QueryContext(ctx, "SELECT id, name, source_name, findings_path, mapping_title, mapping_asset, mapping_severity, mapping_description, mapping_remediation FROM data_adapters")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var adapters []domain2.Adapter
|
||||
for rows.Next() {
|
||||
var a domain2.Adapter
|
||||
rows.Scan(&a.ID, &a.Name, &a.SourceName, &a.FindingsPath, &a.MappingTitle, &a.MappingAsset, &a.MappingSeverity, &a.MappingDescription, &a.MappingRemediation)
|
||||
adapters = append(adapters, a)
|
||||
}
|
||||
return adapters, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) SaveAdapter(ctx context.Context, a domain2.Adapter) error {
|
||||
_, err := s.DB.ExecContext(ctx, `
|
||||
INSERT INTO data_adapters (name, source_name, findings_path, mapping_title, mapping_asset, mapping_severity, mapping_description, mapping_remediation)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
a.Name, a.SourceName, a.FindingsPath, a.MappingTitle, a.MappingAsset, a.MappingSeverity, a.MappingDescription, a.MappingRemediation)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAdapterByID(ctx context.Context, id int) (domain2.Adapter, error) {
|
||||
var a domain2.Adapter
|
||||
query := `
|
||||
SELECT
|
||||
id, name, source_name, findings_path,
|
||||
mapping_title, mapping_asset, mapping_severity,
|
||||
IFNULL(mapping_description, ''), IFNULL(mapping_remediation, ''),
|
||||
created_at, updated_at
|
||||
FROM data_adapters
|
||||
WHERE id = ?`
|
||||
|
||||
err := s.DB.QueryRowContext(ctx, query, id).Scan(
|
||||
&a.ID, &a.Name, &a.SourceName, &a.FindingsPath,
|
||||
&a.MappingTitle, &a.MappingAsset, &a.MappingSeverity,
|
||||
&a.MappingDescription, &a.MappingRemediation,
|
||||
&a.CreatedAt, &a.UpdatedAt,
|
||||
)
|
||||
return a, err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) DeleteAdapter(ctx context.Context, id int) error {
|
||||
_, err := s.DB.ExecContext(ctx, "DELETE FROM data_adapters WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetAdapterByName(ctx context.Context, name string) (domain2.Adapter, error) {
|
||||
var a domain2.Adapter
|
||||
query := `
|
||||
SELECT
|
||||
id, name, source_name, findings_path,
|
||||
mapping_title, mapping_asset, mapping_severity,
|
||||
IFNULL(mapping_description, ''), IFNULL(mapping_remediation, '')
|
||||
FROM data_adapters
|
||||
WHERE name = ?`
|
||||
|
||||
err := s.DB.QueryRowContext(ctx, query, name).Scan(
|
||||
&a.ID, &a.Name, &a.SourceName, &a.FindingsPath,
|
||||
&a.MappingTitle, &a.MappingAsset, &a.MappingSeverity,
|
||||
&a.MappingDescription, &a.MappingRemediation,
|
||||
)
|
||||
return a, err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) ProcessIngestionBatch(ctx context.Context, source, asset string, incoming []domain2.Ticket) error {
|
||||
slaMap, _ := s.buildSLAMap(ctx)
|
||||
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for i := range incoming {
|
||||
if incoming[i].Domain == "" {
|
||||
incoming[i].Domain = "Vulnerability"
|
||||
}
|
||||
if incoming[i].Status == "" {
|
||||
incoming[i].Status = "Waiting to be Triaged"
|
||||
}
|
||||
}
|
||||
|
||||
inserts, reopens, updates, closes, err := s.calculateDiffState(ctx, tx, source, asset, incoming)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.executeBatchMutations(ctx, tx, source, asset, slaMap, inserts, reopens, updates, closes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) calculateDiffState(ctx context.Context, tx *sql.Tx, source, asset string, incoming []domain2.Ticket) (inserts, reopens, descUpdates []domain2.Ticket, autocloses []string, err error) {
|
||||
rows, err := tx.QueryContext(ctx, `SELECT dedupe_hash, status, COALESCE(description, '') FROM tickets WHERE source = ? AND asset_identifier = ?`, source, asset)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type existingRecord struct{ status, description string }
|
||||
existingMap := make(map[string]existingRecord)
|
||||
for rows.Next() {
|
||||
var hash, status, desc string
|
||||
if err := rows.Scan(&hash, &status, &desc); err == nil {
|
||||
existingMap[hash] = existingRecord{status: status, description: desc}
|
||||
}
|
||||
}
|
||||
|
||||
incomingMap := make(map[string]bool)
|
||||
for _, ticket := range incoming {
|
||||
incomingMap[ticket.DedupeHash] = true
|
||||
existing, exists := existingMap[ticket.DedupeHash]
|
||||
if !exists {
|
||||
inserts = append(inserts, ticket)
|
||||
} else {
|
||||
if existing.status == "Patched" {
|
||||
reopens = append(reopens, ticket)
|
||||
}
|
||||
if ticket.Description != "" && ticket.Description != existing.description && existing.status != "Patched" && existing.status != "Risk Accepted" && existing.status != "False Positive" {
|
||||
descUpdates = append(descUpdates, ticket)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for hash, record := range existingMap {
|
||||
if !incomingMap[hash] && record.status != "Patched" && record.status != "Risk Accepted" && record.status != "False Positive" {
|
||||
autocloses = append(autocloses, hash)
|
||||
}
|
||||
}
|
||||
return inserts, reopens, descUpdates, autocloses, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) executeBatchMutations(ctx context.Context, tx *sql.Tx, source, asset string, slaMap map[string]map[string]domain2.SLAPolicy, inserts, reopens, descUpdates []domain2.Ticket, autocloses []string) error {
|
||||
now := time.Now()
|
||||
|
||||
// A. Inserts
|
||||
if len(inserts) > 0 {
|
||||
insertStmt, err := tx.PrepareContext(ctx, `INSERT INTO tickets (source, asset_identifier, title, severity, description, status, dedupe_hash, domain, triage_due_date, remediation_due_date) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer insertStmt.Close()
|
||||
|
||||
for _, t := range inserts {
|
||||
daysToTriage, daysToRemediate := 3, 30
|
||||
if dMap, ok := slaMap[t.Domain]; ok {
|
||||
if policy, ok := dMap[t.Severity]; ok {
|
||||
daysToTriage, daysToRemediate = policy.DaysToTriage, policy.DaysToRemediate
|
||||
}
|
||||
}
|
||||
_, err := insertStmt.ExecContext(ctx, source, asset, t.Title, t.Severity, t.Description, t.Status, t.DedupeHash, t.Domain, now.AddDate(0, 0, daysToTriage), now.AddDate(0, 0, daysToRemediate))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(reopens) > 0 {
|
||||
updateStmt, _ := tx.PrepareContext(ctx, `UPDATE tickets SET status = 'Waiting to be Triaged', patched_at = NULL, triage_due_date = ?, remediation_due_date = ? WHERE dedupe_hash = ?`)
|
||||
defer updateStmt.Close()
|
||||
for _, t := range reopens {
|
||||
updateStmt.ExecContext(ctx, now.AddDate(0, 0, 3), now.AddDate(0, 0, 30), t.DedupeHash) // Using default SLAs for fallback
|
||||
}
|
||||
}
|
||||
|
||||
if len(descUpdates) > 0 {
|
||||
descStmt, _ := tx.PrepareContext(ctx, `UPDATE tickets SET description = ? WHERE dedupe_hash = ?`)
|
||||
defer descStmt.Close()
|
||||
for _, t := range descUpdates {
|
||||
descStmt.ExecContext(ctx, t.Description, t.DedupeHash)
|
||||
}
|
||||
}
|
||||
|
||||
if len(autocloses) > 0 {
|
||||
closeStmt, _ := tx.PrepareContext(ctx, `UPDATE tickets SET status = 'Patched', patched_at = CURRENT_TIMESTAMP WHERE dedupe_hash = ?`)
|
||||
defer closeStmt.Close()
|
||||
for _, hash := range autocloses {
|
||||
closeStmt.ExecContext(ctx, hash)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) LogSync(ctx context.Context, source, status string, records int, errMsg string) error {
|
||||
_, err := s.DB.ExecContext(ctx, `INSERT INTO sync_logs (source, status, records_processed, error_message) VALUES (?, ?, ?, ?)`, source, status, records, errMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetRecentSyncLogs(ctx context.Context, limit int) ([]domain2.SyncLog, error) {
|
||||
rows, err := s.DB.QueryContext(ctx, `SELECT id, source, status, records_processed, IFNULL(error_message, ''), created_at FROM sync_logs ORDER BY id DESC LIMIT ?`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var logs []domain2.SyncLog
|
||||
for rows.Next() {
|
||||
var l domain2.SyncLog
|
||||
rows.Scan(&l.ID, &l.Source, &l.Status, &l.RecordsProcessed, &l.ErrorMessage, &l.CreatedAt)
|
||||
logs = append(logs, l)
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
131
pkg/datastore/sqlite_tickets.go
Normal file
131
pkg/datastore/sqlite_tickets.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (s *SQLiteStore) GetTickets(ctx context.Context) ([]domain.Ticket, error) {
|
||||
rows, err := s.DB.QueryContext(ctx, "SELECT id, title, severity, status FROM tickets LIMIT 100")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tickets []domain.Ticket
|
||||
for rows.Next() {
|
||||
var t domain.Ticket
|
||||
rows.Scan(&t.ID, &t.Title, &t.Severity, &t.Status)
|
||||
tickets = append(tickets, t)
|
||||
}
|
||||
return tickets, nil
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) CreateTicket(ctx context.Context, t *domain.Ticket) error {
|
||||
if t.Status == "" {
|
||||
t.Status = "Waiting to be Triaged"
|
||||
}
|
||||
if t.Domain == "" {
|
||||
t.Domain = "Vulnerability"
|
||||
}
|
||||
if t.Source == "" {
|
||||
t.Source = "Manual"
|
||||
}
|
||||
if t.AssetIdentifier == "" {
|
||||
t.AssetIdentifier = "Default"
|
||||
}
|
||||
|
||||
rawHash := fmt.Sprintf("%s-%s-%s-%s", t.Source, t.AssetIdentifier, t.Title, t.Severity)
|
||||
hashBytes := sha256.Sum256([]byte(rawHash))
|
||||
t.DedupeHash = hex.EncodeToString(hashBytes[:])
|
||||
|
||||
query := `
|
||||
INSERT INTO tickets (
|
||||
domain, source, asset_identifier, title, description, recommended_remediation,
|
||||
severity, status, dedupe_hash,
|
||||
triage_due_date, remediation_due_date, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, DATETIME('now', '+3 days'), DATETIME('now', '+14 days'), CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
`
|
||||
|
||||
res, err := s.DB.ExecContext(ctx, query,
|
||||
t.Domain, t.Source, t.AssetIdentifier, t.Title, t.Description, t.RecommendedRemediation,
|
||||
t.Severity, t.Status, t.DedupeHash,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, _ := res.LastInsertId()
|
||||
t.ID = int(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTicketInline handles a single UI edit and updates the flattened comment tracking
|
||||
func (s *SQLiteStore) UpdateTicketInline(ctx context.Context, ticketID int, severity, description, remediation, comment, actor, status, assignee string) error {
|
||||
query := `
|
||||
UPDATE tickets
|
||||
SET severity = ?, description = ?, recommended_remediation = ?,
|
||||
status = ?, assignee = ?,
|
||||
latest_comment = CASE WHEN ? != '' THEN ? ELSE latest_comment END,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?`
|
||||
|
||||
formattedComment := ""
|
||||
if comment != "" {
|
||||
formattedComment = "[" + actor + "] " + comment
|
||||
}
|
||||
|
||||
_, err := s.DB.ExecContext(ctx, query, severity, description, remediation, status, assignee, formattedComment, formattedComment, ticketID)
|
||||
return err
|
||||
}
|
||||
|
||||
// RejectTicketFromWrangler puts a ticket back into the Holding Pen
|
||||
func (s *SQLiteStore) RejectTicketFromWrangler(ctx context.Context, ticketIDs []int, reason, comment string) error {
|
||||
tx, err := s.DB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, id := range ticketIDs {
|
||||
fullComment := "[Wrangler Reject: " + reason + "] " + comment
|
||||
_, err := tx.ExecContext(ctx, "UPDATE tickets SET status = 'Returned to Security', assignee = 'Unassigned', latest_comment = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?", fullComment, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (s *SQLiteStore) GetTicketByID(ctx context.Context, id int) (domain.Ticket, error) {
|
||||
var t domain.Ticket
|
||||
var triageDue, remDue, created, updated string
|
||||
var patchedAt *string
|
||||
|
||||
query := `SELECT id, domain, source, asset_identifier, title, description, recommended_remediation, severity, status, dedupe_hash, triage_due_date, remediation_due_date, created_at, updated_at, patched_at, assignee, latest_comment FROM tickets WHERE id = ?`
|
||||
|
||||
err := s.DB.QueryRowContext(ctx, query, id).Scan(
|
||||
&t.ID, &t.Domain, &t.Source, &t.AssetIdentifier, &t.Title, &t.Description, &t.RecommendedRemediation, &t.Severity, &t.Status, &t.DedupeHash, &triageDue, &remDue, &created, &updated, &patchedAt, &t.Assignee, &t.LatestComment,
|
||||
)
|
||||
if err != nil {
|
||||
return t, err
|
||||
}
|
||||
|
||||
t.TriageDueDate, _ = time.Parse(time.RFC3339, triageDue)
|
||||
t.RemediationDueDate, _ = time.Parse(time.RFC3339, remDue)
|
||||
t.CreatedAt, _ = time.Parse(time.RFC3339, created)
|
||||
t.UpdatedAt, _ = time.Parse(time.RFC3339, updated)
|
||||
|
||||
if patchedAt != nil {
|
||||
pTime, _ := time.Parse(time.RFC3339, *patchedAt)
|
||||
t.PatchedAt = &pTime
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
16
pkg/domain/adapter.go
Normal file
16
pkg/domain/adapter.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package domain
|
||||
|
||||
// Adapter represents a saved mapping profile for a specific scanner
|
||||
type Adapter struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
SourceName string `json:"source_name"`
|
||||
FindingsPath string `json:"findings_path"`
|
||||
MappingTitle string `json:"mapping_title"`
|
||||
MappingAsset string `json:"mapping_asset"`
|
||||
MappingSeverity string `json:"mapping_severity"`
|
||||
MappingDescription string `json:"mapping_description"`
|
||||
MappingRemediation string `json:"mapping_remediation"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
74
pkg/domain/analytics.go
Normal file
74
pkg/domain/analytics.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package domain
|
||||
|
||||
type ResolutionMetrics struct {
|
||||
Total int
|
||||
Patched int
|
||||
RiskAccepted int
|
||||
FalsePositive int
|
||||
PatchedPct int
|
||||
RiskAccPct int
|
||||
FalsePosPct int
|
||||
}
|
||||
|
||||
type SheriffAnalytics struct {
|
||||
ActiveKEVs int
|
||||
GlobalMTTRDays int
|
||||
OpenCriticals int
|
||||
TotalOverdue int
|
||||
SourceHealth []SourceMetrics
|
||||
Resolution ResolutionMetrics
|
||||
Severity SeverityMetrics
|
||||
TopAssets []AssetMetric
|
||||
}
|
||||
|
||||
type SourceMetrics struct {
|
||||
Source string
|
||||
TotalOpen int
|
||||
Criticals int
|
||||
CisaKEVs int
|
||||
Untriaged int
|
||||
PatchOverdue int
|
||||
PendingRisk int
|
||||
TotalClosed int
|
||||
Patched int
|
||||
RiskAccepted int
|
||||
FalsePositive int
|
||||
TopAssignee string
|
||||
StrategicNote string
|
||||
}
|
||||
|
||||
type FeedItem struct {
|
||||
Actor string
|
||||
ActivityType string
|
||||
NewValue string
|
||||
TimeAgo string
|
||||
}
|
||||
|
||||
type SeverityMetrics struct {
|
||||
Critical int
|
||||
High int
|
||||
Medium int
|
||||
Low int
|
||||
Info int
|
||||
Total int
|
||||
CritPct int
|
||||
HighPct int
|
||||
MedPct int
|
||||
LowPct int
|
||||
InfoPct int
|
||||
}
|
||||
|
||||
type AssetMetric struct {
|
||||
Asset string
|
||||
Count int
|
||||
Percentage int
|
||||
}
|
||||
|
||||
type SyncLog struct {
|
||||
ID int `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Status string `json:"status"`
|
||||
RecordsProcessed int `json:"records_processed"`
|
||||
ErrorMessage string `json:"error_message"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
18
pkg/domain/auth.go
Normal file
18
pkg/domain/auth.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package domain
|
||||
|
||||
import "time"
|
||||
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Email string `json:"email"`
|
||||
FullName string `json:"full_name"`
|
||||
PasswordHash string `json:"-"`
|
||||
GlobalRole string `json:"global_role"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
type Session struct {
|
||||
Token string `json:"token"`
|
||||
UserID int `json:"user_id"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
15
pkg/domain/config.go
Normal file
15
pkg/domain/config.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package domain
|
||||
|
||||
type AppConfig struct {
|
||||
Timezone string `json:"timezone"`
|
||||
BusinessStart int `json:"business_start"`
|
||||
BusinessEnd int `json:"business_end"`
|
||||
DefaultExtensionDays int `json:"default_extension_days"`
|
||||
Backup BackupPolicy `json:"backup"`
|
||||
}
|
||||
|
||||
type BackupPolicy struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
IntervalHours int `json:"interval_hours"`
|
||||
RetentionDays int `json:"retention_days"`
|
||||
}
|
||||
16
pkg/domain/connector.go
Normal file
16
pkg/domain/connector.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package domain
|
||||
|
||||
// ConnectorTemplate defines how to translate third-party JSON into ticket format
|
||||
type ConnectorTemplate struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
SourceDefault string `json:"source_default"`
|
||||
FindingsArrayPath string `json:"findings_array_path"`
|
||||
FieldMappings struct {
|
||||
Title string `json:"title"`
|
||||
AssetIdentifier string `json:"asset_identifier"`
|
||||
Severity string `json:"severity"`
|
||||
Description string `json:"description"`
|
||||
RecommendedRemediation string `json:"recommended_remediation"`
|
||||
} `json:"field_mappings"`
|
||||
}
|
||||
11
pkg/domain/drafts.go
Normal file
11
pkg/domain/drafts.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package domain
|
||||
|
||||
type DraftTicket struct {
|
||||
ID int `json:"id"`
|
||||
ReportID string `json:"report_id"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
AssetIdentifier string `json:"asset_identifier"`
|
||||
RecommendedRemediation string `json:"recommended_remediation"`
|
||||
}
|
||||
11
pkg/domain/export.go
Normal file
11
pkg/domain/export.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package domain
|
||||
|
||||
type ExportState struct {
|
||||
AppConfig AppConfig `json:"app_config"`
|
||||
SLAPolicies []SLAPolicy `json:"sla_policies"`
|
||||
Users []User `json:"users"`
|
||||
Adapters []Adapter `json:"adapters"`
|
||||
Tickets []Ticket `json:"tickets"`
|
||||
Version string `json:"export_version"`
|
||||
ExportedAt string `json:"exported_at"`
|
||||
}
|
||||
95
pkg/domain/store.go
Normal file
95
pkg/domain/store.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Store embeds all sub interfaces for Core
|
||||
type Store interface {
|
||||
TicketStore
|
||||
IdentityStore
|
||||
IngestStore
|
||||
ConfigStore
|
||||
AnalyticsStore
|
||||
DraftStore
|
||||
}
|
||||
|
||||
// TicketStore: Core CRUD and Workflow
|
||||
type TicketStore interface {
|
||||
GetTickets(ctx context.Context) ([]Ticket, error)
|
||||
GetDashboardTickets(ctx context.Context, tabStatus, filter, assetFilter, userEmail, userRole string, limit, offset int) ([]Ticket, int, map[string]int, error)
|
||||
CreateTicket(ctx context.Context, t *Ticket) error
|
||||
GetTicketByID(ctx context.Context, id int) (Ticket, error)
|
||||
UpdateTicketInline(ctx context.Context, ticketID int, severity, description, remediation, comment, actor, status, assignee string) error
|
||||
}
|
||||
|
||||
// IdentityStore: Users, Sessions, and Dispatching
|
||||
type IdentityStore interface {
|
||||
CreateUser(ctx context.Context, email, fullName, passwordHash, globalRole string) (*User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*User, error)
|
||||
GetUserByID(ctx context.Context, id int) (*User, error)
|
||||
GetAllUsers(ctx context.Context) ([]*User, error)
|
||||
GetUserCount(ctx context.Context) (int, error)
|
||||
UpdateUserPassword(ctx context.Context, id int, newPasswordHash string) error
|
||||
UpdateUserRole(ctx context.Context, id int, newRole string) error
|
||||
DeactivateUserAndReassign(ctx context.Context, userID int) error
|
||||
|
||||
CreateSession(ctx context.Context, token string, userID int, expiresAt time.Time) error
|
||||
GetSession(ctx context.Context, token string) (*Session, error)
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
|
||||
GetWranglers(ctx context.Context) ([]User, error)
|
||||
}
|
||||
|
||||
// IngestStore: Scanners, Adapters, and Sync History
|
||||
type IngestStore interface {
|
||||
IngestTickets(ctx context.Context, tickets []Ticket) error
|
||||
ProcessIngestionBatch(ctx context.Context, source string, assetIdentifier string, incoming []Ticket) error
|
||||
|
||||
GetAdapters(ctx context.Context) ([]Adapter, error)
|
||||
GetAdapterByID(ctx context.Context, id int) (Adapter, error)
|
||||
GetAdapterByName(ctx context.Context, name string) (Adapter, error)
|
||||
SaveAdapter(ctx context.Context, adapter Adapter) error
|
||||
DeleteAdapter(ctx context.Context, id int) error
|
||||
|
||||
LogSync(ctx context.Context, source, status string, records int, errMsg string) error
|
||||
GetRecentSyncLogs(ctx context.Context, limit int) ([]SyncLog, error)
|
||||
}
|
||||
|
||||
// ConfigStore: Global System Settings
|
||||
type ConfigStore interface {
|
||||
GetAppConfig(ctx context.Context) (AppConfig, error)
|
||||
UpdateAppConfig(ctx context.Context, config AppConfig) error
|
||||
GetSLAPolicies(ctx context.Context) ([]SLAPolicy, error)
|
||||
UpdateSLAPolicies(ctx context.Context, slas []SLAPolicy) error
|
||||
UpdateBackupPolicy(ctx context.Context, policy BackupPolicy) error
|
||||
ExportSystemState(ctx context.Context) (ExportState, error)
|
||||
}
|
||||
|
||||
// AnalyticsStore: Audit Logs and KPI Metrics
|
||||
type AnalyticsStore interface {
|
||||
GetSheriffAnalytics(ctx context.Context) (SheriffAnalytics, error)
|
||||
GetAnalyticsSummary(ctx context.Context) (map[string]int, error)
|
||||
GetGlobalActivityFeed(ctx context.Context, limit int) ([]FeedItem, error)
|
||||
GetPaginatedActivityFeed(ctx context.Context, filter string, limit int, offset int) ([]FeedItem, int, error)
|
||||
}
|
||||
|
||||
// DraftStore: The Pentest Desk OSS, word docx
|
||||
type DraftStore interface {
|
||||
SaveDraft(ctx context.Context, draft DraftTicket) error
|
||||
GetDraftsByReport(ctx context.Context, reportID string) ([]DraftTicket, error)
|
||||
DeleteDraft(ctx context.Context, draftID string) error
|
||||
UpdateDraft(ctx context.Context, draftID int, payload Ticket) error
|
||||
PromotePentestDrafts(ctx context.Context, reportID string, analystEmail string, tickets []Ticket) error
|
||||
}
|
||||
|
||||
type Authenticator interface {
|
||||
Middleware(next http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type SLACalculator interface {
|
||||
CalculateDueDate(severity string) *time.Time
|
||||
CalculateTrueSLAHours(ctx context.Context, ticketID int, store Store) (float64, error)
|
||||
}
|
||||
61
pkg/domain/ticket.go
Normal file
61
pkg/domain/ticket.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SLAPolicy represents the global SLA configuration per severity
|
||||
type SLAPolicy struct {
|
||||
Domain string `json:"domain"`
|
||||
Severity string `json:"severity"`
|
||||
DaysToRemediate int `json:"days_to_remediate"`
|
||||
MaxExtensions int `json:"max_extensions"`
|
||||
DaysToTriage int `json:"days_to_triage"`
|
||||
}
|
||||
|
||||
// AssetRiskSummary holds the rolled-up vulnerability counts for a single asset
|
||||
type AssetRiskSummary struct {
|
||||
AssetIdentifier string
|
||||
TotalActive int
|
||||
Critical int
|
||||
High int
|
||||
Medium int
|
||||
Low int
|
||||
Info int
|
||||
}
|
||||
|
||||
type Ticket struct {
|
||||
ID int `json:"id"`
|
||||
Domain string `json:"domain"`
|
||||
IsOverdue bool `json:"is_overdue"`
|
||||
DaysToResolve *int `json:"days_to_resolve"`
|
||||
Source string `json:"source"`
|
||||
AssetIdentifier string `json:"asset_identifier"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
RecommendedRemediation string `json:"recommended_remediation"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
|
||||
DedupeHash string `json:"dedupe_hash"`
|
||||
|
||||
PatchEvidence *string `json:"patch_evidence"`
|
||||
OwnerViewedAt *time.Time `json:"owner_viewed_at"`
|
||||
|
||||
TriageDueDate time.Time `json:"triage_due_date"`
|
||||
RemediationDueDate time.Time `json:"remediation_due_date"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
PatchedAt *time.Time `json:"patched_at"`
|
||||
|
||||
SLAString string `json:"sla_string"`
|
||||
Assignee string `json:"assignee"`
|
||||
LatestComment string `json:"latest_comment"`
|
||||
}
|
||||
|
||||
// TicketAssignment represents the many-to-many relationship
|
||||
type TicketAssignment struct {
|
||||
TicketID int `json:"ticket_id"`
|
||||
Assignee string `json:"assignee"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
13
pkg/ingest/handler.go
Normal file
13
pkg/ingest/handler.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package ingest
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
163
pkg/ingest/ingest.go
Normal file
163
pkg/ingest/ingest.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package ingest
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/csv"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (h *Handler) HandleIngest(w http.ResponseWriter, r *http.Request) {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
_, err := decoder.Token()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JSON payload: expected array", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
type groupKey struct {
|
||||
Source string
|
||||
Asset string
|
||||
}
|
||||
groupedTickets := make(map[groupKey][]domain.Ticket)
|
||||
for decoder.More() {
|
||||
var ticket domain.Ticket
|
||||
if err := decoder.Decode(&ticket); err != nil {
|
||||
http.Error(w, "Error parsing ticket object", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if ticket.Status == "" {
|
||||
ticket.Status = "Waiting to be Triaged"
|
||||
}
|
||||
|
||||
if ticket.DedupeHash == "" {
|
||||
hashInput := ticket.Source + "|" + ticket.AssetIdentifier + "|" + ticket.Title
|
||||
hash := sha256.Sum256([]byte(hashInput))
|
||||
ticket.DedupeHash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
key := groupKey{
|
||||
Source: ticket.Source,
|
||||
Asset: ticket.AssetIdentifier,
|
||||
}
|
||||
groupedTickets[key] = append(groupedTickets[key], ticket)
|
||||
}
|
||||
|
||||
_, err = decoder.Token()
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JSON payload termination", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
for key, batch := range groupedTickets {
|
||||
err := h.Store.ProcessIngestionBatch(r.Context(), key.Source, key.Asset, batch)
|
||||
if err != nil {
|
||||
log.Printf("🔥 Ingestion DB Error for Asset %s: %v", key.Asset, err)
|
||||
h.Store.LogSync(r.Context(), key.Source, "Failed", len(batch), err.Error())
|
||||
http.Error(w, "Database error processing batch", http.StatusInternalServerError)
|
||||
return
|
||||
} else {
|
||||
h.Store.LogSync(r.Context(), key.Source, "Success", len(batch), "")
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleCSVIngest(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
http.Error(w, "Failed to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
adapterIDStr := r.FormValue("adapter_id")
|
||||
adapterID, err := strconv.Atoi(adapterIDStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid adapter_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
adapter, err := h.Store.GetAdapterByID(r.Context(), adapterID)
|
||||
if err != nil {
|
||||
http.Error(w, "Adapter mapping not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
file, _, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read file payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := csv.NewReader(file)
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil || len(records) < 2 {
|
||||
http.Error(w, "Invalid or empty CSV format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
headers := records[0]
|
||||
headerMap := make(map[string]int)
|
||||
for i, h := range headers {
|
||||
headerMap[h] = i
|
||||
}
|
||||
|
||||
type groupKey struct {
|
||||
Source string
|
||||
Asset string
|
||||
}
|
||||
groupedTickets := make(map[groupKey][]domain.Ticket)
|
||||
|
||||
for _, row := range records[1:] {
|
||||
ticket := domain.Ticket{
|
||||
Source: adapter.SourceName,
|
||||
Status: "Waiting to be Triaged",
|
||||
}
|
||||
|
||||
if idx, ok := headerMap[adapter.MappingTitle]; ok && idx < len(row) {
|
||||
ticket.Title = row[idx]
|
||||
}
|
||||
if idx, ok := headerMap[adapter.MappingAsset]; ok && idx < len(row) {
|
||||
ticket.AssetIdentifier = row[idx]
|
||||
}
|
||||
if idx, ok := headerMap[adapter.MappingSeverity]; ok && idx < len(row) {
|
||||
ticket.Severity = row[idx]
|
||||
}
|
||||
if idx, ok := headerMap[adapter.MappingDescription]; ok && idx < len(row) {
|
||||
ticket.Description = row[idx]
|
||||
}
|
||||
if adapter.MappingRemediation != "" {
|
||||
if idx, ok := headerMap[adapter.MappingRemediation]; ok && idx < len(row) {
|
||||
ticket.RecommendedRemediation = row[idx]
|
||||
}
|
||||
}
|
||||
|
||||
if ticket.Title != "" && ticket.AssetIdentifier != "" {
|
||||
hashInput := ticket.Source + "|" + ticket.AssetIdentifier + "|" + ticket.Title
|
||||
hash := sha256.Sum256([]byte(hashInput))
|
||||
ticket.DedupeHash = hex.EncodeToString(hash[:])
|
||||
key := groupKey{Source: ticket.Source, Asset: ticket.AssetIdentifier}
|
||||
groupedTickets[key] = append(groupedTickets[key], ticket)
|
||||
}
|
||||
}
|
||||
|
||||
for key, batch := range groupedTickets {
|
||||
err := h.Store.ProcessIngestionBatch(r.Context(), key.Source, key.Asset, batch)
|
||||
if err != nil {
|
||||
log.Printf("🔥 CSV Ingestion Error for Asset %s: %v", key.Asset, err)
|
||||
h.Store.LogSync(r.Context(), key.Source, "Failed", len(batch), err.Error())
|
||||
http.Error(w, "Database error processing CSV batch", http.StatusInternalServerError)
|
||||
return
|
||||
} else {
|
||||
h.Store.LogSync(r.Context(), key.Source, "Success", len(batch), "")
|
||||
}
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
488
pkg/ingest/ingest_test.go
Normal file
488
pkg/ingest/ingest_test.go
Normal file
@@ -0,0 +1,488 @@
|
||||
package ingest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime/debug"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func setupTestIngest(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
|
||||
user, err := store.GetUserByEmail(context.Background(), "vip@RiskRancher.com")
|
||||
if err != nil {
|
||||
user, _ = store.CreateUser(context.Background(), "vip@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
}
|
||||
|
||||
store.CreateSession(context.Background(), "vip_token_999", user.ID, time.Now().Add(1*time.Hour))
|
||||
return &http.Cookie{Name: "session_token", Value: "vip_token_999"}
|
||||
}
|
||||
|
||||
func TestAutoPatchMissingFindings(t *testing.T) {
|
||||
app, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
|
||||
payload1 := []byte(`[
|
||||
{"title": "Vuln A", "severity": "High"},
|
||||
{"title": "Vuln B", "severity": "Medium"}
|
||||
]
|
||||
`)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(payload1))
|
||||
req1.AddCookie(GetVIPCookie(app.Store))
|
||||
rr1 := httptest.NewRecorder()
|
||||
app.HandleIngest(rr1, req1)
|
||||
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets WHERE status = 'Waiting to be Triaged'").Scan(&count)
|
||||
if count != 2 {
|
||||
t.Fatalf("Expected 2 unpatched tickets, got %d", count)
|
||||
}
|
||||
|
||||
payload2 := []byte(` [
|
||||
{"title": "Vuln A", "severity": "High"}
|
||||
]`)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(payload2))
|
||||
req2.AddCookie(GetVIPCookie(app.Store))
|
||||
rr2 := httptest.NewRecorder()
|
||||
app.HandleIngest(rr2, req2)
|
||||
|
||||
var statusB string
|
||||
var patchedAt sql.NullTime
|
||||
|
||||
err := db.QueryRow("SELECT status, patched_at FROM tickets WHERE title = 'Vuln B'").Scan(&statusB, &patchedAt)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query Vuln B: %v", err)
|
||||
}
|
||||
|
||||
if statusB != "Patched" {
|
||||
t.Errorf("Expected Vuln B status to be 'Patched', got '%s'", statusB)
|
||||
}
|
||||
|
||||
if !patchedAt.Valid {
|
||||
t.Errorf("Expected Vuln B to have a patched_at timestamp, but it was NULL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleIngest(t *testing.T) {
|
||||
a, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
|
||||
sendIngestRequest := func(findings []domain.Ticket) *httptest.ResponseRecorder {
|
||||
body, _ := json.Marshal(findings)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
a.HandleIngest(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
||||
t.Run("1. Fresh Ingestion", func(t *testing.T) {
|
||||
findings := []domain.Ticket{
|
||||
{
|
||||
Source: "CrowdStrike",
|
||||
AssetIdentifier: "Server-01",
|
||||
Title: "Malware Detected",
|
||||
Severity: "Critical",
|
||||
},
|
||||
}
|
||||
|
||||
rr := sendIngestRequest(findings)
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets").Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 ticket in DB, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("2. Deduplication", func(t *testing.T) {
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
findings := []domain.Ticket{
|
||||
{
|
||||
Source: "CrowdStrike",
|
||||
AssetIdentifier: "Server-01",
|
||||
Title: "Malware Detected",
|
||||
Severity: "Critical",
|
||||
Description: "Updated Description",
|
||||
},
|
||||
}
|
||||
|
||||
rr := sendIngestRequest(findings)
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets").Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("expected still 1 ticket in DB due to dedupe, got %d", count)
|
||||
}
|
||||
|
||||
var desc string
|
||||
db.QueryRow("SELECT description FROM tickets WHERE title = 'Malware Detected'").Scan(&desc)
|
||||
if desc != "Updated Description" {
|
||||
t.Errorf("expected description to update to 'Updated Description', got '%s'", desc)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("3. Scoped Auto-Patching", func(t *testing.T) {
|
||||
findings := []domain.Ticket{
|
||||
{
|
||||
Source: "CrowdStrike",
|
||||
AssetIdentifier: "Server-01",
|
||||
Title: "Outdated Antivirus",
|
||||
Severity: "High",
|
||||
},
|
||||
}
|
||||
|
||||
rr := sendIngestRequest(findings)
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var totalCount int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets").Scan(&totalCount)
|
||||
if totalCount != 2 {
|
||||
t.Errorf("expected 2 total tickets in DB, got %d", totalCount)
|
||||
}
|
||||
|
||||
var status string
|
||||
db.QueryRow("SELECT status FROM tickets WHERE title = 'Malware Detected'").Scan(&status)
|
||||
if status != "Patched" {
|
||||
t.Errorf("expected missing vulnerability to be auto-patched, but status is '%s'", status)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCSVIngestion(t *testing.T) {
|
||||
app, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO data_adapters (
|
||||
id, name, source_name, findings_path,
|
||||
mapping_title, mapping_asset, mapping_severity, mapping_description, mapping_remediation
|
||||
) VALUES (
|
||||
999, 'Legacy Scanner V1', 'LegacyScan', '.',
|
||||
'Vuln_Name', 'Server_IP', 'Risk_Level', 'Details', 'Fix_Steps'
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to setup test adapter: %v", err)
|
||||
}
|
||||
|
||||
rawCSV := `Vuln_Name,Server_IP,Risk_Level,Details,Junk_Column
|
||||
SQL Injection,192.168.1.50,Critical,Found in login form,ignore_this
|
||||
Outdated Apache,192.168.1.50,High,Upgrade to 2.4.50,ignore_this`
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
part, _ := writer.CreateFormFile("file", "scan_results.csv")
|
||||
part.Write([]byte(rawCSV))
|
||||
|
||||
writer.WriteField("adapter_id", "999")
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest/csv", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
app.HandleCSVIngest(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d. Body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets WHERE source = 'LegacyScan'").Scan(&count)
|
||||
|
||||
if count != 2 {
|
||||
t.Errorf("Expected 2 tickets ingested from CSV, got %d", count)
|
||||
}
|
||||
|
||||
var title, severity string
|
||||
db.QueryRow("SELECT title, severity FROM tickets WHERE title = 'SQL Injection'").Scan(&title, &severity)
|
||||
if severity != "Critical" {
|
||||
t.Errorf("CSV Mapping failed! Expected severity 'Critical', got '%s'", severity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoPatchEdgeCases(t *testing.T) {
|
||||
h, db := setupTestIngest(t) // Swapped 'app' for 'h'
|
||||
defer db.Close()
|
||||
|
||||
db.Exec(`
|
||||
INSERT INTO tickets (source, title, severity, dedupe_hash, status)
|
||||
VALUES ('App B', 'App B Vuln', 'High', 'hash-app-b', 'Waiting to be Triaged')
|
||||
`)
|
||||
|
||||
payload1 := []byte(`[
|
||||
{"source": "App A", "title": "Vuln 1", "severity": "High"},
|
||||
{"source": "App A", "title": "Vuln 2", "severity": "Medium"}
|
||||
]`)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(payload1))
|
||||
req1.AddCookie(GetVIPCookie(h.Store))
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rr1 := httptest.NewRecorder()
|
||||
h.HandleIngest(rr1, req1)
|
||||
|
||||
payload2 := []byte(`[
|
||||
{"source": "App A", "title": "Vuln 1", "severity": "High"}
|
||||
]`)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(payload2))
|
||||
req2.AddCookie(GetVIPCookie(h.Store))
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
|
||||
rr2 := httptest.NewRecorder()
|
||||
h.HandleIngest(rr2, req2)
|
||||
|
||||
var status2 string
|
||||
db.QueryRow("SELECT status FROM tickets WHERE title = 'Vuln 2'").Scan(&status2)
|
||||
if status2 != "Patched" {
|
||||
t.Errorf("Expected Vuln 2 to be 'Patched', got '%s'", status2)
|
||||
}
|
||||
|
||||
var statusB string
|
||||
db.QueryRow("SELECT status FROM tickets WHERE title = 'App B Vuln'").Scan(&statusB)
|
||||
if statusB != "Waiting to be Triaged" {
|
||||
t.Errorf("CRITICAL FAILURE: Blast radius exceeded! App B status changed to '%s'", statusB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleIngest_MultiAssetDiffing(t *testing.T) {
|
||||
// THE GO 1.26 GC TWEAK: Force Go to keep RAM usage under 2GB
|
||||
// This makes the GC run aggressively, trading a tiny bit of CPU for massive RAM savings.
|
||||
previousLimit := debug.SetMemoryLimit(2 * 1024 * 1024 * 1024)
|
||||
defer debug.SetMemoryLimit(previousLimit)
|
||||
|
||||
a, db := setupTestIngest(t)
|
||||
db.Exec(`PRAGMA synchronous = OFF;`)
|
||||
defer db.Close()
|
||||
|
||||
_, err := db.Exec(`INSERT INTO tickets (source, asset_identifier, title, status, severity, dedupe_hash) VALUES
|
||||
('Trivy', 'Server-A', 'Old Vuln A', 'Waiting to be Triaged', 'High', 'hash_A_1'),
|
||||
('Trivy', 'Server-B', 'Old Vuln B', 'Waiting to be Triaged', 'Critical', 'hash_B_1')`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to seed database: %v", err)
|
||||
}
|
||||
|
||||
incomingPayload := []domain.Ticket{
|
||||
{
|
||||
Source: "Trivy",
|
||||
AssetIdentifier: "Server-A",
|
||||
Title: "New Vuln A",
|
||||
Severity: "High",
|
||||
DedupeHash: "hash_A_2",
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(incomingPayload)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
a.HandleIngest(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("Expected 201 Created, got %d", rr.Code)
|
||||
}
|
||||
|
||||
var statusA string
|
||||
db.QueryRow(`SELECT status FROM tickets WHERE dedupe_hash = 'hash_A_1'`).Scan(&statusA)
|
||||
if statusA != "Patched" {
|
||||
t.Errorf("Expected Server-A's old ticket to be Auto-Patched, got '%s'", statusA)
|
||||
}
|
||||
|
||||
var statusB string
|
||||
db.QueryRow(`SELECT status FROM tickets WHERE dedupe_hash = 'hash_B_1'`).Scan(&statusB)
|
||||
if statusB != "Waiting to be Triaged" {
|
||||
t.Errorf("CRITICAL BUG: Server-B's ticket was altered! Expected 'Waiting to be Triaged', got '%s'", statusB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleIngest_OneMillionTicketStressTest(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping 1-million ticket stress test in short mode")
|
||||
}
|
||||
|
||||
a, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
|
||||
numAssets := 10000
|
||||
vulnsPerAsset := 100
|
||||
|
||||
t.Logf("Generating baseline payload for %d tickets...", numAssets*vulnsPerAsset)
|
||||
|
||||
baselinePayload := make([]domain.Ticket, 0, numAssets*vulnsPerAsset)
|
||||
for assetID := 1; assetID <= numAssets; assetID++ {
|
||||
assetName := fmt.Sprintf("Server-%05d", assetID)
|
||||
for vulnID := 1; vulnID <= vulnsPerAsset; vulnID++ {
|
||||
baselinePayload = append(baselinePayload, domain.Ticket{
|
||||
Source: "HeavyLoadTester",
|
||||
AssetIdentifier: assetName,
|
||||
Title: fmt.Sprintf("Vulnerability-%03d", vulnID),
|
||||
Severity: "High",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("Marshaling 1M tickets to JSON...")
|
||||
body1, _ := json.Marshal(baselinePayload)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(body1))
|
||||
rr1 := httptest.NewRecorder()
|
||||
|
||||
t.Log("Hitting API with Baseline 1M Scan...")
|
||||
a.HandleIngest(rr1, req1)
|
||||
|
||||
if rr1.Code != http.StatusCreated {
|
||||
t.Fatalf("Baseline ingest failed with status %d", rr1.Code)
|
||||
}
|
||||
|
||||
var count1 int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM tickets`).Scan(&count1)
|
||||
if count1 != 1000000 {
|
||||
t.Fatalf("Expected 1,000,000 tickets inserted, got %d", count1)
|
||||
}
|
||||
|
||||
t.Log("Generating Diff payload...")
|
||||
|
||||
diffPayload := make([]domain.Ticket, 0, numAssets*vulnsPerAsset)
|
||||
for assetID := 1; assetID <= numAssets; assetID++ {
|
||||
assetName := fmt.Sprintf("Server-%05d", assetID)
|
||||
|
||||
for vulnID := 1; vulnID <= 80; vulnID++ {
|
||||
diffPayload = append(diffPayload, domain.Ticket{
|
||||
Source: "HeavyLoadTester",
|
||||
AssetIdentifier: assetName,
|
||||
Title: fmt.Sprintf("Vulnerability-%03d", vulnID),
|
||||
Severity: "High",
|
||||
})
|
||||
}
|
||||
|
||||
for vulnID := 101; vulnID <= 120; vulnID++ {
|
||||
diffPayload = append(diffPayload, domain.Ticket{
|
||||
Source: "HeavyLoadTester",
|
||||
AssetIdentifier: assetName,
|
||||
Title: fmt.Sprintf("Vulnerability-%03d", vulnID),
|
||||
Severity: "Critical",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("Marshaling Diff payload to JSON...")
|
||||
body2, _ := json.Marshal(diffPayload)
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(body2))
|
||||
rr2 := httptest.NewRecorder()
|
||||
|
||||
t.Log("Hitting API with Diff 1M Scan...")
|
||||
a.HandleIngest(rr2, req2)
|
||||
|
||||
if rr2.Code != http.StatusCreated {
|
||||
t.Fatalf("Diff ingest failed with status %d", rr2.Code)
|
||||
}
|
||||
|
||||
t.Log("Running Assertions...")
|
||||
|
||||
var totalRows int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM tickets`).Scan(&totalRows)
|
||||
if totalRows != 1200000 {
|
||||
t.Errorf("Expected exactly 1,200,000 total rows in DB, got %d", totalRows)
|
||||
}
|
||||
|
||||
var patchedCount int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM tickets WHERE status = 'Patched'`).Scan(&patchedCount)
|
||||
if patchedCount != 200000 {
|
||||
t.Errorf("Expected exactly 200,000 auto-patched tickets, got %d", patchedCount)
|
||||
}
|
||||
|
||||
var openCount int
|
||||
db.QueryRow(`SELECT COUNT(*) FROM tickets WHERE status = 'Waiting to be Triaged'`).Scan(&openCount)
|
||||
if openCount != 1000000 {
|
||||
t.Errorf("Expected exactly 1,000,000 open tickets, got %d", openCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncLogReceipts(t *testing.T) {
|
||||
h, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
db.Exec(`CREATE TABLE IF NOT EXISTS sync_logs (id INTEGER PRIMARY KEY, source TEXT, status TEXT, records_processed INTEGER, error_message TEXT)`)
|
||||
|
||||
payload := []byte(`[{"source": "Dependabot", "asset_identifier": "repo-1", "title": "Vuln 1", "severity": "High"}]`)
|
||||
req1 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(payload))
|
||||
req1.AddCookie(GetVIPCookie(h.Store))
|
||||
req1.Header.Set("Content-Type", "application/json")
|
||||
h.HandleIngest(httptest.NewRecorder(), req1)
|
||||
|
||||
badPayload := []byte(`[{"source": "Dependabot", "title": "Vuln 1", "severity": "High", "status": "GarbageStatus"}]`)
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodPost, "/api/ingest", bytes.NewBuffer(badPayload))
|
||||
req2.AddCookie(GetVIPCookie(h.Store))
|
||||
req2.Header.Set("Content-Type", "application/json")
|
||||
h.HandleIngest(httptest.NewRecorder(), req2)
|
||||
|
||||
var successCount, failCount, processed int
|
||||
db.QueryRow("SELECT COUNT(*), MAX(records_processed) FROM sync_logs WHERE source = 'Dependabot' AND status = 'Success'").Scan(&successCount, &processed)
|
||||
db.QueryRow("SELECT COUNT(*) FROM sync_logs WHERE status = 'Failed'").Scan(&failCount)
|
||||
|
||||
if successCount != 1 || processed != 1 {
|
||||
t.Errorf("System failed to log successful sync receipt. Got count: %d, processed: %d", successCount, processed)
|
||||
}
|
||||
if failCount != 1 {
|
||||
t.Errorf("System failed to log failed sync receipt. Got count: %d", failCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUIFileDropIngestion(t *testing.T) {
|
||||
h, db := setupTestIngest(t)
|
||||
defer db.Close()
|
||||
|
||||
res, err := db.Exec(`INSERT INTO data_adapters (name, source_name, mapping_title, mapping_asset, mapping_severity) VALUES ('UI-Tool', 'UITool', 'Name', 'Host', 'Risk')`)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to seed adapter: %v", err)
|
||||
}
|
||||
adapterID, _ := res.LastInsertId()
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
part, _ := writer.CreateFormFile("file", "test_findings.csv")
|
||||
part.Write([]byte("Name,Host,Risk\nUnauthorized Access,10.0.0.1,Critical"))
|
||||
|
||||
_ = writer.WriteField("adapter_id", fmt.Sprintf("%d", adapterID))
|
||||
writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/ingest/csv", body)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.AddCookie(GetVIPCookie(h.Store))
|
||||
rr := httptest.NewRecorder()
|
||||
h.HandleCSVIngest(rr, req)
|
||||
|
||||
if rr.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201 Created, got %d: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM tickets WHERE source = 'UITool'").Scan(&count)
|
||||
if count != 1 {
|
||||
t.Errorf("UI Drop failed: expected 1 ticket, got %d", count)
|
||||
}
|
||||
}
|
||||
113
pkg/report/docx_html.go
Normal file
113
pkg/report/docx_html.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Relationships maps the pkg rId to the actual media file
|
||||
type Relationships struct {
|
||||
XMLName xml.Name `xml:"Relationships"`
|
||||
Rel []struct {
|
||||
Id string `xml:"Id,attr"`
|
||||
Target string `xml:"Target,attr"`
|
||||
} `xml:"Relationship"`
|
||||
}
|
||||
|
||||
func ServeDOCXAsHTML(w http.ResponseWriter, docxPath string) {
|
||||
r, err := zip.OpenReader(docxPath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to open DOCX archive", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
relsMap := make(map[string]string)
|
||||
for _, f := range r.File {
|
||||
if f.Name == "word/_rels/document.xml.rels" {
|
||||
rc, _ := f.Open()
|
||||
var rels Relationships
|
||||
xml.NewDecoder(rc).Decode(&rels)
|
||||
rc.Close()
|
||||
for _, rel := range rels.Rel {
|
||||
relsMap[rel.Id] = rel.Target
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
mediaMap := make(map[string]string)
|
||||
for _, f := range r.File {
|
||||
if strings.HasPrefix(f.Name, "word/media/") {
|
||||
rc, _ := f.Open()
|
||||
data, _ := io.ReadAll(rc)
|
||||
rc.Close()
|
||||
|
||||
ext := strings.TrimPrefix(filepath.Ext(f.Name), ".")
|
||||
if ext == "jpeg" || ext == "jpg" {
|
||||
ext = "jpeg"
|
||||
}
|
||||
b64 := base64.StdEncoding.EncodeToString(data)
|
||||
mediaMap[f.Name] = fmt.Sprintf("data:image/%s;base64,%s", ext, b64)
|
||||
}
|
||||
}
|
||||
|
||||
var htmlOutput bytes.Buffer
|
||||
var inParagraph bool
|
||||
|
||||
for _, f := range r.File {
|
||||
if f.Name == "word/document.xml" {
|
||||
rc, _ := f.Open()
|
||||
decoder := xml.NewDecoder(rc)
|
||||
|
||||
for {
|
||||
token, err := decoder.Token()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch se := token.(type) {
|
||||
case xml.StartElement:
|
||||
if se.Name.Local == "p" {
|
||||
htmlOutput.WriteString("<p style='margin-bottom: 10px;'>")
|
||||
inParagraph = true
|
||||
}
|
||||
if se.Name.Local == "t" {
|
||||
var text string
|
||||
decoder.DecodeElement(&text, &se)
|
||||
htmlOutput.WriteString(text)
|
||||
}
|
||||
if se.Name.Local == "blip" {
|
||||
for _, attr := range se.Attr {
|
||||
if attr.Name.Local == "embed" {
|
||||
targetPath := relsMap[attr.Value]
|
||||
fullMediaPath := "word/" + targetPath
|
||||
|
||||
if b64URI, exists := mediaMap[fullMediaPath]; exists {
|
||||
imgTag := fmt.Sprintf(`<br><img src="%s" style="max-width: 100%%; height: auto; border: 1px solid #cbd5e1; border-radius: 4px; margin: 15px 0; cursor: pointer;" class="pentest-img" title="Click to extract image"><br>`, b64URI)
|
||||
htmlOutput.WriteString(imgTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case xml.EndElement:
|
||||
if se.Name.Local == "p" && inParagraph {
|
||||
htmlOutput.WriteString("</p>\n")
|
||||
inParagraph = false
|
||||
}
|
||||
}
|
||||
}
|
||||
rc.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
w.Write(htmlOutput.Bytes())
|
||||
}
|
||||
107
pkg/report/drafts.go
Normal file
107
pkg/report/drafts.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/auth"
|
||||
domain2 "epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func (h *Handler) HandleSaveDraft(w http.ResponseWriter, r *http.Request) {
|
||||
reportID := r.PathValue("id")
|
||||
|
||||
var draft domain2.DraftTicket
|
||||
if err := json.NewDecoder(r.Body).Decode(&draft); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
draft.ReportID = reportID
|
||||
|
||||
if err := h.Store.SaveDraft(r.Context(), draft); err != nil {
|
||||
http.Error(w, "DB Error: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleGetDrafts(w http.ResponseWriter, r *http.Request) {
|
||||
reportID := r.PathValue("id")
|
||||
|
||||
drafts, err := h.Store.GetDraftsByReport(r.Context(), reportID)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get drafts", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(drafts)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleDeleteDraft(w http.ResponseWriter, r *http.Request) {
|
||||
draftID := r.PathValue("draft_id")
|
||||
|
||||
if err := h.Store.DeleteDraft(r.Context(), draftID); err != nil {
|
||||
http.Error(w, "Failed to delete draft", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (h *Handler) HandlePromoteDrafts(w http.ResponseWriter, r *http.Request) {
|
||||
reportIDStr := r.PathValue("id")
|
||||
if reportIDStr == "" {
|
||||
http.Error(w, "Invalid Report ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userIDVal := r.Context().Value(auth.UserIDKey)
|
||||
if userIDVal == nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.Store.GetUserByID(r.Context(), userIDVal.(int))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to identify user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
analystEmail := user.Email
|
||||
|
||||
var payload []domain2.Ticket
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := h.Store.PromotePentestDrafts(r.Context(), reportIDStr, analystEmail, payload); err != nil {
|
||||
http.Error(w, "Database error during promotion: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleUpdateDraft(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := r.PathValue("id")
|
||||
draftID, err := strconv.Atoi(idStr)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid draft ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var payload domain2.Ticket
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Store.UpdateDraft(r.Context(), draftID, payload); err != nil {
|
||||
http.Error(w, "Failed to auto-save draft", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
13
pkg/report/handler.go
Normal file
13
pkg/report/handler.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
57
pkg/report/parser.go
Normal file
57
pkg/report/parser.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ExtractJSONField traverses an unmarshaled JSON object using dot notation.
|
||||
func ExtractJSONField(data any, path string) string {
|
||||
if path == "" || data == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.Split(path, ".")
|
||||
current := data
|
||||
|
||||
for _, part := range parts {
|
||||
if current == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := current.(type) {
|
||||
case map[string]any:
|
||||
val, ok := v[part]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
current = val
|
||||
|
||||
case []any:
|
||||
idx, err := strconv.Atoi(part)
|
||||
if err != nil || idx < 0 || idx >= len(v) {
|
||||
return ""
|
||||
}
|
||||
current = v[idx]
|
||||
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
if current == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch v := current.(type) {
|
||||
case string:
|
||||
return v
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
||||
case bool:
|
||||
return strconv.FormatBool(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
68
pkg/report/parser_test.go
Normal file
68
pkg/report/parser_test.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractJSONField(t *testing.T) {
|
||||
semgrepRaw := []byte(`{
|
||||
"check_id": "crypto-bad-mac",
|
||||
"extra": {
|
||||
"severity": "WARNING",
|
||||
"message": "Use of weak MAC"
|
||||
}
|
||||
}`)
|
||||
var semgrep map[string]any
|
||||
json.Unmarshal(semgrepRaw, &semgrep)
|
||||
|
||||
trivyRaw := []byte(`{
|
||||
"VulnerabilityID": "CVE-2021-44228",
|
||||
"PkgName": "log4j-core",
|
||||
"Severity": "CRITICAL"
|
||||
}`)
|
||||
var trivy map[string]any
|
||||
json.Unmarshal(trivyRaw, &trivy)
|
||||
|
||||
openvasRaw := []byte(`{
|
||||
"name": "Cleartext Transmission",
|
||||
"host": {
|
||||
"details": [
|
||||
{"ip": "192.168.1.50"},
|
||||
{"ip": "10.0.0.5"}
|
||||
]
|
||||
},
|
||||
"threat": "High"
|
||||
}`)
|
||||
var openvas map[string]any
|
||||
json.Unmarshal(openvasRaw, &openvas)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
finding any
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"Semgrep Flat", semgrep, "check_id", "crypto-bad-mac"},
|
||||
{"Semgrep Nested", semgrep, "extra.severity", "WARNING"},
|
||||
{"Semgrep Deep Nested", semgrep, "extra.message", "Use of weak MAC"},
|
||||
|
||||
{"Trivy Flat 1", trivy, "VulnerabilityID", "CVE-2021-44228"},
|
||||
{"Trivy Flat 2", trivy, "Severity", "CRITICAL"},
|
||||
|
||||
{"OpenVAS Flat", openvas, "threat", "High"},
|
||||
{"OpenVAS Array Index", openvas, "host.details.0.ip", "192.168.1.50"},
|
||||
|
||||
{"Missing Field", trivy, "does.not.exist", ""},
|
||||
{"Empty Path", trivy, "", ""},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := ExtractJSONField(tc.finding, tc.path)
|
||||
if result != tc.expected {
|
||||
t.Errorf("Path '%s': expected '%s', got '%s'", tc.path, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
131
pkg/report/reports.go
Normal file
131
pkg/report/reports.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var UploadDir = "./testdata"
|
||||
|
||||
// HandleUploadReport safely receives and stores the pentest file
|
||||
func (h *Handler) HandleUploadReport(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(50 << 20); err != nil {
|
||||
http.Error(w, "Failed to parse form or file too large", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
http.Error(w, "Missing 'file' field in upload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
cleanName := filepath.Base(header.Filename)
|
||||
if cleanName == "." || cleanName == "/" {
|
||||
cleanName = "uploaded_report.bin"
|
||||
}
|
||||
|
||||
os.MkdirAll(UploadDir, 0755)
|
||||
|
||||
destPath := filepath.Join(UploadDir, cleanName)
|
||||
destFile, err := os.Create(destPath)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to save file to disk", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
if _, err := io.Copy(destFile, file); err != nil {
|
||||
http.Error(w, "Error writing file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
fmt.Fprintf(w, `{"file_id": "%s"}`, cleanName)
|
||||
}
|
||||
|
||||
// HandleViewReport streams the file to the iframe, converting DOCX if needed
|
||||
func (h *Handler) HandleViewReport(w http.ResponseWriter, r *http.Request) {
|
||||
fileID := r.PathValue("id")
|
||||
cleanName := filepath.Base(fileID)
|
||||
filePath := filepath.Join(UploadDir, cleanName)
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
http.Error(w, "Report not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(cleanName))
|
||||
|
||||
if ext == ".pdf" {
|
||||
w.Header().Set("Content-Type", "application/pdf")
|
||||
w.Header().Set("Content-Disposition", "inline; filename="+cleanName)
|
||||
http.ServeFile(w, r, filePath)
|
||||
return
|
||||
}
|
||||
|
||||
if ext == ".docx" {
|
||||
ServeDOCXAsHTML(w, filePath)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, "Unsupported file type. Please upload PDF or DOCX.", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func (h *Handler) HandleImageUpload(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
Base64Data string `json:"image_data"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(payload.Base64Data, ",", 2)
|
||||
if len(parts) != 2 {
|
||||
http.Error(w, "Invalid Base64 image format", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ext := ".png"
|
||||
if strings.Contains(parts[0], "jpeg") || strings.Contains(parts[0], "jpg") {
|
||||
ext = ".jpg"
|
||||
}
|
||||
|
||||
rawBase64 := parts[1]
|
||||
imgBytes, err := base64.StdEncoding.DecodeString(rawBase64)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to decode Base64 data", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
randBytes := make([]byte, 8)
|
||||
rand.Read(randBytes)
|
||||
fileName := fmt.Sprintf("img_%x%s", randBytes, ext)
|
||||
|
||||
uploadDir := filepath.Join("data", "testdata", "images")
|
||||
if err := os.MkdirAll(uploadDir, 0755); err != nil {
|
||||
http.Error(w, "Failed to create directory structure", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
savePath := filepath.Join(uploadDir, fileName)
|
||||
if err := os.WriteFile(savePath, imgBytes, 0644); err != nil {
|
||||
http.Error(w, "Failed to save image to disk", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
publicURL := "/testdata/images/" + fileName
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]string{"url": publicURL})
|
||||
}
|
||||
126
pkg/report/reports_test.go
Normal file
126
pkg/report/reports_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func setupTestReport(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
user, err := store.GetUserByEmail(context.Background(), "vip@RiskRancher.com")
|
||||
if err != nil {
|
||||
user, _ = store.CreateUser(context.Background(), "vip@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
}
|
||||
|
||||
store.CreateSession(context.Background(), "vip_token_999", user.ID, time.Now().Add(1*time.Hour))
|
||||
return &http.Cookie{Name: "session_token", Value: "vip_token_999"}
|
||||
}
|
||||
|
||||
func TestUploadAndViewReports(t *testing.T) {
|
||||
h, db := setupTestReport(t)
|
||||
defer db.Close()
|
||||
|
||||
t.Run("1. Test PDF Upload and View", func(t *testing.T) {
|
||||
body := new(bytes.Buffer)
|
||||
writer := multipart.NewWriter(body)
|
||||
part, _ := writer.CreateFormFile("file", "test_report.pdf")
|
||||
part.Write([]byte("%PDF-1.4 Fake PDF Content"))
|
||||
writer.Close()
|
||||
|
||||
reqUp := httptest.NewRequest(http.MethodPost, "/api/reports/upload", body)
|
||||
reqUp.AddCookie(GetVIPCookie(h.Store))
|
||||
reqUp.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rrUp := httptest.NewRecorder()
|
||||
h.HandleUploadReport(rrUp, reqUp)
|
||||
|
||||
reqView := httptest.NewRequest(http.MethodGet, "/api/reports/view/test_report.pdf", nil)
|
||||
reqView.AddCookie(GetVIPCookie(h.Store))
|
||||
reqView.SetPathValue("id", "test_report.pdf")
|
||||
rrView := httptest.NewRecorder()
|
||||
h.HandleViewReport(rrView, reqView)
|
||||
|
||||
if rrView.Code != http.StatusOK {
|
||||
t.Fatalf("Expected 200 OK for PDF View, got %d", rrView.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("2. Test DOCX to HTML", func(t *testing.T) {
|
||||
buf := new(bytes.Buffer)
|
||||
zipWriter := zip.NewWriter(buf)
|
||||
docWriter, _ := zipWriter.Create("word/document.xml")
|
||||
docWriter.Write([]byte(`<w:document><w:body><w:p><w:r><w:t>Cross-Site Scripting</w:t></w:r></w:p></w:body></w:document>`))
|
||||
zipWriter.Close()
|
||||
|
||||
body := new(bytes.Buffer)
|
||||
writer := multipart.NewWriter(body)
|
||||
part, _ := writer.CreateFormFile("file", "fake_pentest.docx")
|
||||
part.Write(buf.Bytes())
|
||||
writer.Close()
|
||||
|
||||
reqUp := httptest.NewRequest(http.MethodPost, "/api/reports/upload", body)
|
||||
reqUp.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
rrUp := httptest.NewRecorder()
|
||||
h.HandleUploadReport(rrUp, reqUp)
|
||||
|
||||
reqView := httptest.NewRequest(http.MethodGet, "/api/reports/view/fake_pentest.docx", nil)
|
||||
reqView.SetPathValue("id", "fake_pentest.docx")
|
||||
rrView := httptest.NewRecorder()
|
||||
h.HandleViewReport(rrView, reqView)
|
||||
|
||||
if !strings.Contains(rrView.Body.String(), "Cross-Site Scripting") {
|
||||
t.Errorf("DOCX-to-HTML failed. Body: %s", rrView.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDraftQueueLifecycle(t *testing.T) {
|
||||
h, db := setupTestReport(t)
|
||||
defer db.Close()
|
||||
|
||||
reportID := "report-uuid-123.pdf"
|
||||
|
||||
// Save Draft
|
||||
draftPayload := []byte(`{"title": "SQLi", "severity": "High", "description": "Page 4"}`)
|
||||
reqPost := httptest.NewRequest(http.MethodPost, "/api/drafts/report/"+reportID, bytes.NewBuffer(draftPayload))
|
||||
reqPost.SetPathValue("id", reportID)
|
||||
rrPost := httptest.NewRecorder()
|
||||
h.HandleSaveDraft(rrPost, reqPost)
|
||||
|
||||
if rrPost.Code >= 400 {
|
||||
t.Fatalf("Failed to save draft! HTTP Code: %d, Error: %s", rrPost.Code, rrPost.Body.String())
|
||||
}
|
||||
|
||||
reqGet := httptest.NewRequest(http.MethodGet, "/api/drafts/report/"+reportID, nil)
|
||||
reqGet.SetPathValue("id", reportID)
|
||||
rrGet := httptest.NewRecorder()
|
||||
h.HandleGetDrafts(rrGet, reqGet)
|
||||
|
||||
var drafts []domain.DraftTicket
|
||||
json.NewDecoder(rrGet.Body).Decode(&drafts)
|
||||
if len(drafts) != 1 || drafts[0].Title != "SQLi" {
|
||||
t.Fatalf("Draft GET mismatch")
|
||||
}
|
||||
|
||||
// Delete Draft
|
||||
reqDel := httptest.NewRequest(http.MethodDelete, "/api/drafts/1", nil)
|
||||
reqDel.SetPathValue("draft_id", "1")
|
||||
rrDel := httptest.NewRecorder()
|
||||
h.HandleDeleteDraft(rrDel, reqDel)
|
||||
}
|
||||
BIN
pkg/report/testdata/fake_pentest.docx
vendored
Normal file
BIN
pkg/report/testdata/fake_pentest.docx
vendored
Normal file
Binary file not shown.
1
pkg/report/testdata/test_report.pdf
vendored
Normal file
1
pkg/report/testdata/test_report.pdf
vendored
Normal file
@@ -0,0 +1 @@
|
||||
%PDF-1.4 Fake PDF Content
|
||||
34
pkg/server/app.go
Normal file
34
pkg/server/app.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/sla"
|
||||
)
|
||||
|
||||
type App struct {
|
||||
Store domain.Store
|
||||
Router *http.ServeMux
|
||||
Auth domain.Authenticator
|
||||
SLA domain.SLACalculator
|
||||
}
|
||||
|
||||
type FreeAuth struct{}
|
||||
|
||||
func (f *FreeAuth) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// In the OSS version, we just pass the request to the next handler for now.
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// NewApp creates a Risk Rancher Core application with OSS defaults.
|
||||
func NewApp(store domain.Store) *App {
|
||||
return &App{
|
||||
Store: store,
|
||||
Router: http.NewServeMux(),
|
||||
Auth: &FreeAuth{},
|
||||
SLA: sla.NewSLACalculator(),
|
||||
}
|
||||
}
|
||||
116
pkg/server/routes.go
Normal file
116
pkg/server/routes.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/adapters"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/admin"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/analytics"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/auth"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/ingest"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/report"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/tickets"
|
||||
"epigas.gitea.cloud/RiskRancher/core/ui"
|
||||
)
|
||||
|
||||
func RegisterRoutes(app *App) {
|
||||
|
||||
authH := auth.NewHandler(app.Store)
|
||||
adminH := admin.NewHandler(app.Store)
|
||||
ticketH := tickets.NewHandler(app.Store)
|
||||
ingestH := ingest.NewHandler(app.Store)
|
||||
adapterH := adapters.NewHandler(app.Store)
|
||||
reportH := report.NewHandler(app.Store)
|
||||
analyticsH := analytics.NewHandler(app.Store)
|
||||
|
||||
protected := func(h http.HandlerFunc) http.Handler {
|
||||
return authH.RequireAuth(http.HandlerFunc(h))
|
||||
}
|
||||
protectedUI := func(h http.HandlerFunc) http.Handler {
|
||||
return authH.RequireUIAuth(http.HandlerFunc(h))
|
||||
}
|
||||
sheriffOnly := func(h http.HandlerFunc) http.Handler {
|
||||
return authH.RequireAuth(authH.RequireRole("Sheriff")(http.HandlerFunc(h)))
|
||||
}
|
||||
adminOnly := func(h http.HandlerFunc) http.Handler {
|
||||
return authH.RequireAuth(authH.RequireAnyRole("Sheriff", "Wrangler")(http.HandlerFunc(h)))
|
||||
}
|
||||
|
||||
// =========================================================
|
||||
// PUBLIC ROUTES
|
||||
// =========================================================
|
||||
app.Router.Handle("GET /login", ui.HandleLoginUI())
|
||||
app.Router.Handle("GET /register", ui.HandleRegisterUI())
|
||||
|
||||
app.Router.HandleFunc("POST /api/auth/register", authH.HandleRegister)
|
||||
app.Router.HandleFunc("POST /api/auth/login", authH.HandleLogin)
|
||||
app.Router.HandleFunc("POST /api/auth/logout", authH.HandleLogout)
|
||||
|
||||
// =========================================================
|
||||
// PROTECTED ROUTES
|
||||
// =========================================================
|
||||
app.Router.Handle("GET /api/wranglers", protected(adminH.HandleGetWranglers))
|
||||
app.Router.Handle("GET /", http.RedirectHandler("/dashboard", http.StatusSeeOther))
|
||||
app.Router.Handle("GET /dashboard", protectedUI(ui.HandleDashboard(app.Store)))
|
||||
|
||||
// Core Tickets
|
||||
app.Router.Handle("GET /api/tickets", protected(ticketH.HandleGetTickets))
|
||||
app.Router.Handle("POST /api/tickets", protected(ticketH.HandleCreateTicket))
|
||||
app.Router.Handle("PATCH /api/tickets/{id}", protected(ticketH.HandleUpdateTicket))
|
||||
|
||||
// Ingestion
|
||||
app.Router.Handle("POST /api/ingest", protected(ingestH.HandleIngest))
|
||||
app.Router.Handle("POST /api/ingest/csv", protected(ingestH.HandleCSVIngest))
|
||||
app.Router.Handle("POST /api/ingest/{name}", protected(adapterH.HandleAdapterIngest))
|
||||
|
||||
// Adapters & Configuration
|
||||
app.Router.Handle("GET /api/adapters", protected(adapterH.HandleGetAdapters))
|
||||
app.Router.Handle("GET /api/config", protected(adminH.HandleGetConfig))
|
||||
|
||||
// Analytics
|
||||
app.Router.Handle("GET /api/analytics/summary", protected(analyticsH.HandleGetAnalyticsSummary))
|
||||
|
||||
// Pentest Reports & Drafts (PDF PARSER - Free Lead Magnet!)
|
||||
app.Router.Handle("POST /api/reports/upload", protected(reportH.HandleUploadReport))
|
||||
app.Router.Handle("GET /api/reports/view/{id}", protected(reportH.HandleViewReport))
|
||||
app.Router.Handle("POST /api/drafts/report/{id}", protected(reportH.HandleSaveDraft))
|
||||
app.Router.Handle("GET /api/drafts/report/{id}", protected(reportH.HandleGetDrafts))
|
||||
app.Router.Handle("DELETE /api/drafts/{draft_id}", protected(reportH.HandleDeleteDraft))
|
||||
|
||||
// =========================================================
|
||||
// SHERIFF & ADMIN ONLY
|
||||
// =========================================================
|
||||
|
||||
app.Router.Handle("GET /admin", sheriffOnly(ui.HandleAdminDashboard(app.Store)))
|
||||
|
||||
app.Router.Handle("POST /api/adapters", adminOnly(adapterH.HandleCreateAdapter))
|
||||
app.Router.Handle("DELETE /api/adapters/{id}", adminOnly(adapterH.HandleDeleteAdapter))
|
||||
|
||||
app.Router.Handle("GET /api/admin/export", sheriffOnly(adminH.HandleExportState))
|
||||
app.Router.Handle("GET /api/admin/check-updates", sheriffOnly(adminH.HandleCheckUpdates))
|
||||
app.Router.Handle("POST /api/admin/shutdown", sheriffOnly(adminH.HandleShutdown))
|
||||
|
||||
app.Router.Handle("GET /api/admin/users", adminOnly(adminH.HandleGetUsers))
|
||||
app.Router.Handle("POST /api/admin/users", sheriffOnly(adminH.HandleCreateUser))
|
||||
app.Router.Handle("PATCH /api/admin/users/{id}/reset-password", sheriffOnly(adminH.HandleAdminResetPassword))
|
||||
app.Router.Handle("PATCH /api/admin/users/{id}/role", sheriffOnly(adminH.HandleUpdateUserRole))
|
||||
app.Router.Handle("DELETE /api/admin/users/{id}", sheriffOnly(adminH.HandleDeactivateUser))
|
||||
app.Router.Handle("GET /api/admin/logs", sheriffOnly(adminH.HandleGetLogs))
|
||||
|
||||
app.Router.Handle("GET /static/", ui.StaticHandler())
|
||||
|
||||
// =========================================================
|
||||
// UI EXTENSIONS
|
||||
// =========================================================
|
||||
|
||||
app.Router.Handle("GET /ingest", protectedUI(ui.HandleIngestUI(app.Store)))
|
||||
app.Router.Handle("GET /admin/adapters/new", protectedUI(ui.HandleAdapterBuilderUI(app.Store)))
|
||||
|
||||
// Word Docx Parser
|
||||
app.Router.Handle("GET /reports/parser/{id}", protectedUI(ui.HandleParserUI(app.Store)))
|
||||
app.Router.Handle("POST /api/reports/promote/{id}", protected(reportH.HandlePromoteDrafts))
|
||||
app.Router.Handle("GET /reports/upload", protectedUI(ui.HandlePentestUploadUI(app.Store)))
|
||||
app.Router.Handle("PUT /api/drafts/{id}", protected(reportH.HandleUpdateDraft))
|
||||
app.Router.Handle("POST /api/images/upload", protected(reportH.HandleImageUpload))
|
||||
app.Router.Handle("GET /uploads/", http.StripPrefix("/testdata/", http.FileServer(http.Dir("./data/testdata"))))
|
||||
}
|
||||
127
pkg/sla/sla.go
Normal file
127
pkg/sla/sla.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package sla
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
// DefaultSLACalculator implements the SLACalculator interface
|
||||
type DefaultSLACalculator struct {
|
||||
Timezone string
|
||||
BusinessStart int
|
||||
BusinessEnd int
|
||||
Holidays map[string]bool
|
||||
}
|
||||
|
||||
// NewSLACalculator returns the interface
|
||||
func NewSLACalculator() domain.SLACalculator {
|
||||
return &DefaultSLACalculator{
|
||||
Timezone: "UTC",
|
||||
BusinessStart: 9,
|
||||
BusinessEnd: 17,
|
||||
Holidays: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateDueDate for the finding based on SLA
|
||||
func (c *DefaultSLACalculator) CalculateDueDate(severity string) *time.Time {
|
||||
var days int
|
||||
switch severity {
|
||||
case "Critical":
|
||||
days = 3
|
||||
case "High":
|
||||
days = 14
|
||||
case "Medium":
|
||||
days = 30
|
||||
case "Low":
|
||||
days = 90
|
||||
default:
|
||||
days = 30
|
||||
}
|
||||
|
||||
loc, err := time.LoadLocation(c.Timezone)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Invalid timezone '%s', falling back to UTC", c.Timezone)
|
||||
loc = time.UTC
|
||||
}
|
||||
|
||||
nowLocal := time.Now().In(loc)
|
||||
dueDate := c.AddBusinessDays(nowLocal, days)
|
||||
return &dueDate
|
||||
}
|
||||
|
||||
// AddBusinessDays for working days not weekends and some holidays
|
||||
func (c *DefaultSLACalculator) AddBusinessDays(start time.Time, businessDays int) time.Time {
|
||||
current := start
|
||||
added := 0
|
||||
for added < businessDays {
|
||||
current = current.AddDate(0, 0, 1)
|
||||
weekday := current.Weekday()
|
||||
dateStr := current.Format("2006-01-02")
|
||||
if weekday != time.Saturday && weekday != time.Sunday && !c.Holidays[dateStr] {
|
||||
added++
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
// CalculateTrueSLAHours based on the time of action for ticket
|
||||
func (c *DefaultSLACalculator) CalculateTrueSLAHours(ctx context.Context, ticketID int, store domain.Store) (float64, error) {
|
||||
appConfig, err := store.GetAppConfig(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
ticket, err := store.GetTicketByID(ctx, ticketID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
end := time.Now()
|
||||
if ticket.PatchedAt != nil {
|
||||
end = *ticket.PatchedAt
|
||||
}
|
||||
|
||||
totalActiveBusinessHours := c.calculateBusinessHoursBetween(ticket.CreatedAt, end, appConfig)
|
||||
return totalActiveBusinessHours, nil
|
||||
}
|
||||
|
||||
// calculateBusinessHoursBetween calculates strict working hours between two timestamps
|
||||
func (c *DefaultSLACalculator) calculateBusinessHoursBetween(start, end time.Time, config domain.AppConfig) float64 {
|
||||
loc, _ := time.LoadLocation(config.Timezone)
|
||||
start = start.In(loc)
|
||||
end = end.In(loc)
|
||||
|
||||
if start.After(end) {
|
||||
return 0
|
||||
}
|
||||
|
||||
var activeHours float64
|
||||
current := start
|
||||
|
||||
for current.Before(end) {
|
||||
nextHour := current.Add(time.Hour)
|
||||
if nextHour.After(end) {
|
||||
nextHour = end
|
||||
}
|
||||
|
||||
weekday := current.Weekday()
|
||||
dateStr := current.Format("2006-01-02")
|
||||
hour := current.Hour()
|
||||
|
||||
isWeekend := weekday == time.Saturday || weekday == time.Sunday
|
||||
isHoliday := c.Holidays[dateStr]
|
||||
isBusinessHour := hour >= config.BusinessStart && hour < config.BusinessEnd
|
||||
|
||||
if !isWeekend && !isHoliday && isBusinessHour {
|
||||
activeHours += nextHour.Sub(current).Hours()
|
||||
}
|
||||
|
||||
current = nextHour
|
||||
}
|
||||
|
||||
return activeHours
|
||||
}
|
||||
116
pkg/sla/sla_test.go
Normal file
116
pkg/sla/sla_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package sla_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// GetSLAPolicy simulates the core engine function that fetches SLA rules
|
||||
func GetSLAPolicy(db *sql.DB, domain string, severity string) (daysToRemediate int, maxExtensions int, err error) {
|
||||
query := `SELECT days_to_remediate, max_extensions FROM sla_policies WHERE domain = ? AND severity = ?`
|
||||
err = db.QueryRow(query, domain, severity).Scan(&daysToRemediate, &maxExtensions)
|
||||
return daysToRemediate, maxExtensions, err
|
||||
}
|
||||
|
||||
// setupTestDB spins up an isolated, in-memory database for testing
|
||||
func setupTestDB(t *testing.T) *sql.DB {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open test database: %v", err)
|
||||
}
|
||||
|
||||
schema := `
|
||||
CREATE TABLE domains (name TEXT PRIMARY KEY);
|
||||
CREATE TABLE sla_policies (
|
||||
domain TEXT NOT NULL,
|
||||
severity TEXT NOT NULL,
|
||||
days_to_remediate INTEGER NOT NULL,
|
||||
max_extensions INTEGER NOT NULL DEFAULT 3,
|
||||
PRIMARY KEY (domain, severity)
|
||||
);
|
||||
INSERT INTO domains (name) VALUES ('Vulnerability'), ('Privacy'), ('Incident');
|
||||
INSERT INTO sla_policies (domain, severity, days_to_remediate, max_extensions) VALUES
|
||||
('Vulnerability', 'Critical', 14, 1),
|
||||
('Vulnerability', 'High', 30, 2),
|
||||
('Privacy', 'Critical', 3, 0),
|
||||
('Incident', 'Critical', 1, 0);
|
||||
`
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
t.Fatalf("Failed to execute test schema: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func TestSLAEngine(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
severity string
|
||||
expectDays int
|
||||
expectExtensions int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "VM Critical (Standard)",
|
||||
domain: "Vulnerability",
|
||||
severity: "Critical",
|
||||
expectDays: 14,
|
||||
expectExtensions: 1,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Privacy Critical (Strict 72-hour, No Extensions)",
|
||||
domain: "Privacy",
|
||||
severity: "Critical",
|
||||
expectDays: 3,
|
||||
expectExtensions: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Incident Critical (24-hour, No Extensions)",
|
||||
domain: "Incident",
|
||||
severity: "Critical",
|
||||
expectDays: 1,
|
||||
expectExtensions: 0,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Unknown Domain (Should Fail)",
|
||||
domain: "PhysicalSecurity",
|
||||
severity: "Critical",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Unknown Severity (Should Fail)",
|
||||
domain: "Vulnerability",
|
||||
severity: "SuperCritical",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
days, extensions, err := GetSLAPolicy(db, tt.domain, tt.severity)
|
||||
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Fatalf("expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
|
||||
if tt.expectError {
|
||||
return
|
||||
}
|
||||
|
||||
if days != tt.expectDays {
|
||||
t.Errorf("expected %d days, got %d", tt.expectDays, days)
|
||||
}
|
||||
if extensions != tt.expectExtensions {
|
||||
t.Errorf("expected %d max extensions, got %d", tt.expectExtensions, extensions)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
15
pkg/tickets/handler.go
Normal file
15
pkg/tickets/handler.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tickets
|
||||
|
||||
import (
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
// Handler encapsulates all Ticket-related HTTP logic
|
||||
type Handler struct {
|
||||
Store domain.Store
|
||||
}
|
||||
|
||||
// NewHandler creates a new Tickets Handler
|
||||
func NewHandler(store domain.Store) *Handler {
|
||||
return &Handler{Store: store}
|
||||
}
|
||||
73
pkg/tickets/handlers_test.go
Normal file
73
pkg/tickets/handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package tickets
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/datastore"
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
func setupTestTickets(t *testing.T) (*Handler, *sql.DB) {
|
||||
db := datastore.InitDB(":memory:")
|
||||
store := datastore.NewSQLiteStore(db)
|
||||
return NewHandler(store), db
|
||||
}
|
||||
|
||||
// GetVIPCookie creates a dummy Sheriff user and an active session,
|
||||
func GetVIPCookie(store domain.Store) *http.Cookie {
|
||||
|
||||
user, err := store.GetUserByEmail(context.Background(), "vip_test@RiskRancher.com")
|
||||
if err != nil {
|
||||
user, _ = store.CreateUser(context.Background(), "vip_test@RiskRancher.com", "Test VIP", "hash", "Sheriff")
|
||||
}
|
||||
|
||||
token := "vip_test_token_999"
|
||||
store.CreateSession(context.Background(), token, user.ID, time.Now().Add(1*time.Hour))
|
||||
|
||||
return &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: token,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSingleTicket(t *testing.T) {
|
||||
app, db := setupTestTickets(t)
|
||||
defer db.Close()
|
||||
|
||||
payload := []byte(`{
|
||||
"title": "Manual Pentest Finding: XSS",
|
||||
"description": "Found reflected XSS on the search page.",
|
||||
"recommended_remediation": "Sanitize user input.",
|
||||
"severity": "High"
|
||||
}`)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/tickets", bytes.NewBuffer(payload))
|
||||
req.AddCookie(GetVIPCookie(app.Store))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
app.HandleCreateTicket(rr, req)
|
||||
|
||||
if status := rr.Code; status != http.StatusCreated {
|
||||
t.Fatalf("Expected status %v, got %v. Body: %s", http.StatusCreated, status, rr.Body.String())
|
||||
}
|
||||
|
||||
var createdTicket domain.Ticket
|
||||
if err := json.NewDecoder(rr.Body).Decode(&createdTicket); err != nil {
|
||||
t.Fatalf("Failed to decode JSON response: %v", err)
|
||||
}
|
||||
|
||||
if createdTicket.ID == 0 {
|
||||
t.Errorf("Expected database to generate an ID")
|
||||
}
|
||||
if createdTicket.DedupeHash == "" {
|
||||
t.Errorf("Expected engine to generate a dedupe hash")
|
||||
}
|
||||
}
|
||||
74
pkg/tickets/tickets.go
Normal file
74
pkg/tickets/tickets.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package tickets
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"epigas.gitea.cloud/RiskRancher/core/pkg/domain"
|
||||
)
|
||||
|
||||
type InlineUpdateRequest struct {
|
||||
Severity string `json:"severity"`
|
||||
Comment string `json:"comment"`
|
||||
Description string `json:"description"`
|
||||
RecommendedRemediation string `json:"recommended_remediation"`
|
||||
Actor string `json:"actor"`
|
||||
Status string `json:"status"`
|
||||
Assignee string `json:"assignee"`
|
||||
}
|
||||
|
||||
type BulkUpdateRequest struct {
|
||||
TicketIDs []int `json:"ticket_ids"`
|
||||
Status string `json:"status"`
|
||||
Comment string `json:"comment"`
|
||||
Assignee string `json:"assignee"`
|
||||
Actor string `json:"actor"`
|
||||
}
|
||||
|
||||
type MagistrateReviewRequest struct {
|
||||
Action string `json:"action"`
|
||||
Actor string `json:"actor"`
|
||||
Comment string `json:"comment"`
|
||||
ExtensionDays int `json:"extension_days"`
|
||||
}
|
||||
|
||||
func (h *Handler) HandleUpdateTicket(w http.ResponseWriter, r *http.Request) {
|
||||
id, _ := strconv.Atoi(r.PathValue("id"))
|
||||
var req InlineUpdateRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
if err := h.Store.UpdateTicketInline(r.Context(), id, req.Severity, req.Description, req.RecommendedRemediation, req.Comment, req.Actor, req.Status, req.Assignee); err != nil {
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// HandleGetTickets fetches a list of tickets via the API
|
||||
func (h *Handler) HandleGetTickets(w http.ResponseWriter, r *http.Request) {
|
||||
tickets, err := h.Store.GetTickets(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, "Database error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(tickets)
|
||||
}
|
||||
|
||||
// HandleCreateTicket creates a single ticket via the API
|
||||
func (h *Handler) HandleCreateTicket(w http.ResponseWriter, r *http.Request) {
|
||||
var t domain.Ticket
|
||||
if err := json.NewDecoder(r.Body).Decode(&t); err != nil {
|
||||
http.Error(w, "Invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.Store.CreateTicket(r.Context(), &t); err != nil {
|
||||
http.Error(w, "Failed to create ticket", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(t)
|
||||
}
|
||||
Reference in New Issue
Block a user