package db
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
)
// =============================================================================
// Test Helpers
// =============================================================================
// resetGlobalState puts the package-level event-log state back to its zero
// value so tests do not leak configuration into each other. While holding
// eventLogMu it closes any open eventLogFile (the Close error is ignored),
// clears eventLogFile together with eventLogEnc, and zeroes eventLogConfig.
func resetGlobalState() {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()
	if f := eventLogFile; f != nil {
		f.Close()
		eventLogFile, eventLogEnc = nil, nil
	}
	eventLogConfig = EventLogConfig{}
}
// setupTestDB opens an in-memory DuckDB database (empty DSN) and creates
// test_table(id VARCHAR PRIMARY KEY, name VARCHAR, value INTEGER) in it.
// Any failure aborts the calling test via t.Fatalf. The caller owns the
// returned handle and must Close it.
func setupTestDB(t *testing.T) *sql.DB {
	t.Helper()
	conn, err := sql.Open("duckdb", "")
	if err != nil {
		t.Fatalf("Failed to open in-memory DuckDB: %v", err)
	}
	const ddl = "CREATE TABLE test_table (id VARCHAR PRIMARY KEY, name VARCHAR, value INTEGER)"
	if _, err := conn.Exec(ddl); err != nil {
		// Release the handle before aborting; the caller never receives it.
		conn.Close()
		t.Fatalf("Failed to create test table: %v", err)
	}
	return conn
}
// readEventsFile loads the JSONL event log at path: each non-empty,
// newline-separated line is decoded as one TransactionEvent, in file order.
// It returns the read error, or the first line that fails to decode, with
// no partial result. An empty file yields a nil slice and nil error.
func readEventsFile(path string) ([]TransactionEvent, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var out []TransactionEvent
	lines := bytes.Split(raw, []byte("\n"))
	for i := range lines {
		if len(lines[i]) == 0 {
			continue // blank lines, including the one after the final newline
		}
		var ev TransactionEvent
		if err := json.Unmarshal(lines[i], &ev); err != nil {
			return nil, err
		}
		out = append(out, ev)
	}
	return out, nil
}
// Assertion helpers using standard library
// assertEqual reports a non-fatal test error unless expected and actual are
// reflect.DeepEqual. If msg is given, msg[0] is printed verbatim as a prefix
// (it is not a format string); further msg elements are ignored.
func assertEqual(t *testing.T, expected, actual interface{}, msg ...string) {
	t.Helper()
	if reflect.DeepEqual(expected, actual) {
		return
	}
	if len(msg) == 0 {
		t.Errorf("expected %v, got %v", expected, actual)
		return
	}
	t.Errorf("%s: expected %v, got %v", msg[0], expected, actual)
}
// assertNotEqual reports a non-fatal test error when expected and actual are
// equal. Equality is reflect.DeepEqual, matching assertEqual. A plain == on
// interface{} values panics at run time when the dynamic type is not
// comparable (slices, maps, funcs), and it compares pointers by identity
// instead of by contents. If msg is given, msg[0] is printed verbatim as a
// prefix.
func assertNotEqual(t *testing.T, expected, actual interface{}, msg ...string) {
	t.Helper()
	if !reflect.DeepEqual(expected, actual) {
		return
	}
	if len(msg) > 0 {
		t.Errorf("%s: expected %v to not equal %v", msg[0], expected, actual)
		return
	}
	t.Errorf("expected %v to not equal %v", expected, actual)
}
// assertNil reports a non-fatal test error unless value is nil. An untyped
// nil counts, and so does a nil value of a nilable kind stored in the
// interface (e.g. (*os.File)(nil), a nil slice or map). If msg is given,
// msg[0] is printed verbatim as a prefix.
func assertNil(t *testing.T, value interface{}, msg ...string) {
	t.Helper()
	if isTypedNil(value) {
		return
	}
	if len(msg) > 0 {
		t.Errorf("%s: expected nil, got %v", msg[0], value)
	} else {
		t.Errorf("expected nil, got %v", value)
	}
}
// isTypedNil reports whether v is nil. That covers an untyped nil interface
// and an interface wrapping a nil chan, func, map, pointer, slice or
// unsafe.Pointer, e.g. (*os.File)(nil).
func isTypedNil(v interface{}) bool {
	if v == nil {
		return true
	}
	switch rv := reflect.ValueOf(v); rv.Kind() {
	case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return rv.IsNil()
	}
	return false
}
// assertNotNil reports a non-fatal test error when value is nil, using the
// same notion of nil as assertNil. Comparing value == nil alone is not
// enough: a nil pointer passed as interface{} is a non-nil interface, so
// checks like assertNotNil(t, eventLogFile) would never fail. If msg is
// given, msg[0] is printed verbatim as a prefix.
func assertNotNil(t *testing.T, value interface{}, msg ...string) {
	t.Helper()
	if !isTypedNil(value) {
		return
	}
	if len(msg) > 0 {
		t.Errorf("%s: expected non-nil value", msg[0])
	} else {
		t.Errorf("expected non-nil value")
	}
}
// assertTrue reports a non-fatal test error when value is false. If msg is
// given, msg[0] is printed verbatim as a prefix.
func assertTrue(t *testing.T, value bool, msg ...string) {
	t.Helper()
	switch {
	case value:
		// condition holds; nothing to report
	case len(msg) > 0:
		t.Errorf("%s: expected true, got false", msg[0])
	default:
		t.Errorf("expected true, got false")
	}
}
// assertFalse reports a non-fatal test error when value is true. If msg is
// given, msg[0] is printed verbatim as a prefix.
func assertFalse(t *testing.T, value bool, msg ...string) {
	t.Helper()
	if !value {
		return
	}
	if len(msg) == 0 {
		t.Errorf("expected false, got true")
		return
	}
	t.Errorf("%s: expected false, got true", msg[0])
}
// assertError reports a non-fatal test error when err is nil. If msg is
// given, msg[0] is printed verbatim as a prefix.
func assertError(t *testing.T, err error, msg ...string) {
	t.Helper()
	switch {
	case err != nil:
		// an error was produced, as expected
	case len(msg) > 0:
		t.Errorf("%s: expected error, got nil", msg[0])
	default:
		t.Errorf("expected error, got nil")
	}
}
// assertNoError reports a non-fatal test error, including err's text, when
// err is non-nil. If msg is given, msg[0] is printed verbatim as a prefix.
// Execution continues after a failure, so callers that go on to use a value
// returned alongside err must guard it themselves.
func assertNoError(t *testing.T, err error, msg ...string) {
	t.Helper()
	if err == nil {
		return
	}
	if len(msg) == 0 {
		t.Errorf("expected no error, got %v", err)
		return
	}
	t.Errorf("%s: expected no error, got %v", msg[0], err)
}
// assertLen reports a non-fatal test error when the measured length actual
// differs from expected. If msg is given, msg[0] is printed verbatim as a
// prefix.
func assertLen(t *testing.T, expected, actual int, msg ...string) {
	t.Helper()
	if actual == expected {
		return
	}
	if len(msg) > 0 {
		t.Errorf("%s: expected length %d, got %d", msg[0], expected, actual)
		return
	}
	t.Errorf("expected length %d, got %d", expected, actual)
}
// assertContains reports a non-fatal test error unless substr occurs in s.
// Both strings are shown %q-quoted on failure. If msg is given, msg[0] is
// printed verbatim as a prefix.
func assertContains(t *testing.T, s, substr string, msg ...string) {
	t.Helper()
	if found := strings.Contains(s, substr); found {
		return
	}
	if len(msg) == 0 {
		t.Errorf("expected %q to contain %q", s, substr)
		return
	}
	t.Errorf("%s: expected %q to contain %q", msg[0], s, substr)
}
// assertGreater reports a non-fatal test error unless a is strictly greater
// than b. If msg is given, msg[0] is printed verbatim as a prefix.
func assertGreater(t *testing.T, a, b int64, msg ...string) {
	t.Helper()
	if a > b {
		return
	}
	if len(msg) > 0 {
		t.Errorf("%s: expected %d > %d", msg[0], a, b)
		return
	}
	t.Errorf("expected %d > %d", a, b)
}
// =============================================================================
// Category 1: Pure Function Tests
// =============================================================================
// TestIsMutation checks which SQL strings isMutation classifies as
// data-modifying: INSERT/UPDATE/DELETE in any case and after leading
// whitespace, including behind a leading WITH (CTE) clause, but not SELECTs.
func TestIsMutation(t *testing.T) {
	tests := []struct {
		name     string
		sql      string
		expected bool
	}{
		// INSERT variations
		{"INSERT uppercase", "INSERT INTO test VALUES (1)", true},
		{"INSERT lowercase", "insert into test values (1)", true},
		{"INSERT with leading space", " INSERT INTO test VALUES (1)", true},
		{"INSERT with leading newline", "\n\tINSERT INTO test VALUES (1)", true},
		// Note: SQL with leading comment is not detected as mutation
		// because isMutation checks HasPrefix after TrimSpace, and "--" is not INSERT/UPDATE/DELETE
		// UPDATE variations
		{"UPDATE uppercase", "UPDATE test SET x = 1", true},
		{"UPDATE lowercase", "update test set x = 1", true},
		{"UPDATE with WHERE", "UPDATE test SET x = 1 WHERE id = 1", true},
		// DELETE variations
		{"DELETE uppercase", "DELETE FROM test WHERE x = 1", true},
		{"DELETE lowercase", "delete from test where x = 1", true},
		// SELECT (not mutation)
		{"SELECT uppercase", "SELECT * FROM test", false},
		{"SELECT lowercase", "select * from test", false},
		{"SELECT with WHERE", "SELECT * FROM test WHERE id = 1", false},
		// WITH clause (CTE) with mutation
		{"CTE with INSERT", "WITH cte AS (SELECT 1) INSERT INTO test SELECT * FROM cte", true},
		{"CTE with UPDATE", "WITH cte AS (SELECT 1) UPDATE test SET x = 1", true},
		{"CTE with DELETE", "WITH cte AS (SELECT 1) DELETE FROM test", true},
		{"CTE lowercase with insert", "with cte as (select 1) insert into test select * from cte", true},
		// WITH clause (CTE) without mutation
		{"CTE with SELECT only", "WITH cte AS (SELECT 1) SELECT * FROM cte", false},
		{"CTE lowercase with select", "with cte as (select 1) select * from cte", false},
		// Edge cases
		{"empty string", "", false},
		{"whitespace only", " ", false},
		{"just SELECT keyword", "SELECT", false},
		{"just INSERT keyword", "INSERT", true},
		{"just UPDATE keyword", "UPDATE", true},
		{"just DELETE keyword", "DELETE", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Format the failure here: the assert helpers print their message
			// argument verbatim, so passing "isMutation(%q)" to them would show
			// a literal %q and never the offending SQL.
			if got := isMutation(tt.sql); got != tt.expected {
				t.Errorf("isMutation(%q): expected %v, got %v", tt.sql, tt.expected, got)
			}
		})
	}
}
func TestMarshalParam(t *testing.T) {
t.Run("nil", func(t *testing.T) {
result := marshalParam(nil)
assertNil(t, result)
})
t.Run("time.Time", func(t *testing.T) {
tm := time.Date(2026, 2, 18, 14, 30, 0, 0, time.UTC)
result := marshalParam(tm)
assertEqual(t, "2026-02-18T14:30:00Z", result)
})
t.Run("*time.Time nil", func(t *testing.T) {
var tm *time.Time
result := marshalParam(tm)
assertNil(t, result)
})
t.Run("*time.Time with value", func(t *testing.T) {
tm := time.Date(2026, 2, 18, 14, 30, 0, 123456789, time.UTC)
result := marshalParam(&tm)
assertEqual(t, "2026-02-18T14:30:00.123456789Z", result)
})
t.Run("time.Time with nanoseconds", func(t *testing.T) {
tm := time.Date(2026, 2, 18, 14, 30, 0, 999999999, time.UTC)
result := marshalParam(tm)
assertEqual(t, "2026-02-18T14:30:00.999999999Z", result)
})
t.Run("time.Time with timezone", func(t *testing.T) {
loc, _ := time.LoadLocation("Pacific/Auckland")
tm := time.Date(2026, 2, 19, 10, 30, 0, 0, loc)
result := marshalParam(tm)
// Should contain timezone offset
assertContains(t, result.(string), "+13:00")
})
t.Run("string", func(t *testing.T) {
result := marshalParam("hello world")
assertEqual(t, "hello world", result)
})
t.Run("*string nil", func(t *testing.T) {
var s *string
result := marshalParam(s)
assertNil(t, result)
})
t.Run("*string with value", func(t *testing.T) {
s := "hello"
result := marshalParam(&s)
assertEqual(t, "hello", result)
})
t.Run("int types", func(t *testing.T) {
assertEqual(t, int(42), marshalParam(int(42)))
assertEqual(t, int8(42), marshalParam(int8(42)))
assertEqual(t, int16(42), marshalParam(int16(42)))
assertEqual(t, int32(42), marshalParam(int32(42)))
assertEqual(t, int64(42), marshalParam(int64(42)))
assertEqual(t, uint(42), marshalParam(uint(42)))
assertEqual(t, uint8(42), marshalParam(uint8(42)))
assertEqual(t, uint16(42), marshalParam(uint16(42)))
assertEqual(t, uint32(42), marshalParam(uint32(42)))
assertEqual(t, uint64(42), marshalParam(uint64(42)))
})
t.Run("negative int", func(t *testing.T) {
assertEqual(t, int(-42), marshalParam(int(-42)))
assertEqual(t, int64(-42), marshalParam(int64(-42)))
})
t.Run("float types", func(t *testing.T) {
assertEqual(t, float32(3.14), marshalParam(float32(3.14)))
assertEqual(t, float64(3.14), marshalParam(float64(3.14)))
})
t.Run("bool", func(t *testing.T) {
assertEqual(t, true, marshalParam(true))
assertEqual(t, false, marshalParam(false))
})
t.Run("[]byte", func(t *testing.T) {
b := []byte("hello")
result := marshalParam(b)
assertEqual(t, b, result)
})
t.Run("unknown type", func(t *testing.T) {
type MyType struct{ X int }
result := marshalParam(MyType{X: 42})
// fmt.Sprintf("%v", MyType{X: 42}) produces "{42}"
assertContains(t, result.(string), "42")
})
t.Run("slice", func(t *testing.T) {
s := []string{"a", "b", "c"}
result := marshalParam(s)
assertEqual(t, "[a b c]", result)
})
t.Run("map", func(t *testing.T) {
m := map[string]int{"a": 1}
result := marshalParam(m)
assertContains(t, result.(string), "a")
})
}
// TestQueryRecordMarshalJSON checks the JSON form of QueryRecord: an object
// with an "sql" string and a "parameters" array whose elements are the
// marshalled parameter values (times as RFC 3339 strings, nil as null).
func TestQueryRecordMarshalJSON(t *testing.T) {
	// roundTrip marshals qr, decodes it into a generic map and returns that
	// map plus its "parameters" array. It stops the subtest if encoding or
	// decoding fails, if "parameters" is not a JSON array, or if the array
	// does not hold exactly wantParams elements. That makes the indexing in
	// the subtests safe; the old non-fatal checks let it panic.
	roundTrip := func(t *testing.T, qr QueryRecord, wantParams int) (map[string]interface{}, []interface{}) {
		t.Helper()
		data, err := json.Marshal(qr)
		if err != nil {
			t.Fatalf("json.Marshal: %v", err)
		}
		var result map[string]interface{}
		if err := json.Unmarshal(data, &result); err != nil {
			t.Fatalf("json.Unmarshal(%s): %v", data, err)
		}
		params, ok := result["parameters"].([]interface{})
		if !ok {
			t.Fatalf(`"parameters" is %T, want a JSON array (payload %s)`, result["parameters"], data)
		}
		if len(params) != wantParams {
			t.Fatalf("expected %d parameters, got %d (payload %s)", wantParams, len(params), data)
		}
		return result, params
	}
	t.Run("basic types", func(t *testing.T) {
		result, params := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES (?, ?)",
			Parameters: []interface{}{"id123", 42},
		}, 2)
		assertEqual(t, "INSERT INTO test VALUES (?, ?)", result["sql"])
		assertEqual(t, "id123", params[0])
		assertEqual(t, 42.0, params[1]) // JSON numbers are floats
	})
	t.Run("with time.Time", func(t *testing.T) {
		tm := time.Date(2026, 2, 18, 14, 30, 0, 0, time.UTC)
		_, params := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES (?)",
			Parameters: []interface{}{tm},
		}, 1)
		assertEqual(t, "2026-02-18T14:30:00Z", params[0])
	})
	t.Run("with nil parameter", func(t *testing.T) {
		_, params := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES (?)",
			Parameters: []interface{}{nil},
		}, 1)
		assertNil(t, params[0])
	})
	t.Run("empty parameters", func(t *testing.T) {
		_, params := roundTrip(t, QueryRecord{
			SQL:        "SELECT 1",
			Parameters: []interface{}{},
		}, 0)
		assertLen(t, 0, len(params))
	})
	t.Run("multiple param types", func(t *testing.T) {
		_, params := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES (?, ?, ?, ?, ?)",
			Parameters: []interface{}{"string", 42, true, nil, 3.14},
		}, 5)
		assertEqual(t, "string", params[0])
		assertEqual(t, 42.0, params[1])
		assertEqual(t, true, params[2])
		assertNil(t, params[3])
		assertEqual(t, 3.14, params[4])
	})
	t.Run("special characters in SQL", func(t *testing.T) {
		// Decoding succeeding proves the quotes were escaped into valid JSON.
		result, _ := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES ('O''Brien', \"test\")",
			Parameters: []interface{}{},
		}, 0)
		sqlText, _ := result["sql"].(string)
		assertContains(t, sqlText, "O''Brien")
	})
	t.Run("unicode in parameters", func(t *testing.T) {
		_, params := roundTrip(t, QueryRecord{
			SQL:        "INSERT INTO test VALUES (?)",
			Parameters: []interface{}{"日本語 🎵"},
		}, 1)
		assertEqual(t, "日本語 🎵", params[0])
	})
}
// =============================================================================
// Category 2: Global State Tests
// =============================================================================
// TestSetEventLogConfig checks that SetEventLogConfig stores the
// configuration later returned by GetEventLogConfig. It also checks that
// switching to a different path while a log file is open leaves the
// package-level eventLogFile handle nil.
func TestSetEventLogConfig(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("set enabled with path", func(t *testing.T) {
		// Each subtest resets first so it never observes state left by a sibling.
		resetGlobalState()
		// This subtest only inspects the stored config; it never calls
		// ensureEventLogFile.
		cfg := EventLogConfig{
			Enabled: true,
			Path:    "/tmp/test.jsonl",
		}
		SetEventLogConfig(cfg)
		got := GetEventLogConfig()
		assertTrue(t, got.Enabled)
		assertEqual(t, "/tmp/test.jsonl", got.Path)
	})
	t.Run("set disabled", func(t *testing.T) {
		resetGlobalState()
		// A path is supplied together with Enabled: false; only the flag is asserted.
		cfg := EventLogConfig{
			Enabled: false,
			Path:    "/tmp/test.jsonl",
		}
		SetEventLogConfig(cfg)
		got := GetEventLogConfig()
		assertFalse(t, got.Enabled)
	})
	t.Run("change path while file open", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		path1 := filepath.Join(tmpDir, "events1.jsonl")
		path2 := filepath.Join(tmpDir, "events2.jsonl")
		// Set first config and open file
		// NOTE(review): any value ensureEventLogFile returns is discarded here;
		// the assertNotNil below is what would reveal a failed open.
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: path1})
		ensureEventLogFile()
		// eventLogFile is read directly, without eventLogMu; this subtest
		// starts no goroutines of its own.
		assertNotNil(t, eventLogFile)
		// Change path - should close first file
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: path2})
		// File handle should be nil (will reopen on next ensure)
		// Note: SetEventLogConfig closes the file, sets eventLogFile = nil
		assertNil(t, eventLogFile)
	})
}
// TestGetEventLogConfig checks that GetEventLogConfig reports the zero
// configuration right after a reset, and echoes back what SetEventLogConfig
// stored.
func TestGetEventLogConfig(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	cases := []struct {
		name        string
		apply       func() // mutates the global config before reading it; nil keeps the reset state
		wantEnabled bool
		wantPath    string
	}{
		{name: "default state"},
		{
			name: "after set",
			apply: func() {
				SetEventLogConfig(EventLogConfig{Enabled: true, Path: "/test/path.jsonl"})
			},
			wantEnabled: true,
			wantPath:    "/test/path.jsonl",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			resetGlobalState()
			if tc.apply != nil {
				tc.apply()
			}
			got := GetEventLogConfig()
			if tc.wantEnabled {
				assertTrue(t, got.Enabled)
			} else {
				assertFalse(t, got.Enabled)
			}
			assertEqual(t, tc.wantPath, got.Path)
		})
	}
}
// TestCloseEventLog checks that CloseEventLog succeeds when no log file is
// open. It also checks that closing an open log clears eventLogFile and
// eventLogEnc and disables eventLogConfig, and that a second close returns
// nil without panicking.
func TestCloseEventLog(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("close with no file", func(t *testing.T) {
		resetGlobalState()
		err := CloseEventLog()
		assertNoError(t, err)
	})
	t.Run("close with open file", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		logPath := filepath.Join(tmpDir, "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		// NOTE(review): any value ensureEventLogFile returns is discarded; the
		// assertNotNil below is what would reveal a failed open.
		ensureEventLogFile()
		assertNotNil(t, eventLogFile)
		err := CloseEventLog()
		assertNoError(t, err)
		// Verify state is reset
		// (globals are read directly, without eventLogMu; this subtest starts
		// no goroutines of its own)
		assertFalse(t, eventLogConfig.Enabled)
		assertNil(t, eventLogFile)
		assertNil(t, eventLogEnc)
	})
	t.Run("double close", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		logPath := filepath.Join(tmpDir, "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		ensureEventLogFile()
		err := CloseEventLog()
		assertNoError(t, err)
		// Second close should not panic
		// and, per the assertion below, must also return nil.
		err = CloseEventLog()
		assertNoError(t, err)
	})
}
// =============================================================================
// Category 3: Integration Tests
// =============================================================================
// TestBeginLoggedTx checks that BeginLoggedTx returns a wrapper that
// carries the given tool name (empty is allowed), an empty non-nil query
// log, and a start time taken at Begin.
func TestBeginLoggedTx(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	t.Run("creates transaction", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test_tool")
		// Fatal, not Error: every check below dereferences tx.
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		assertNotNil(t, tx)
		assertEqual(t, "test_tool", tx.toolName)
		assertNotNil(t, tx.queries)
		assertLen(t, 0, len(tx.queries))
		assertFalse(t, tx.startTime.IsZero())
	})
	t.Run("empty tool name is allowed", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "")
		if err != nil {
			t.Fatalf("BeginLoggedTx with empty tool name: %v", err)
		}
		defer tx.Rollback()
		assertNotNil(t, tx)
		assertEqual(t, "", tx.toolName)
	})
	t.Run("initial state is clean", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		assertLen(t, 0, len(tx.queries))
		assertFalse(t, tx.startTime.IsZero())
		// Verify startTime is recent (within last second)
		elapsed := time.Since(tx.startTime)
		assertTrue(t, elapsed < time.Second, "startTime should be recent")
	})
}
// TestLoggedTx_ExecContext checks that successful INSERT/UPDATE/DELETE
// statements run through ExecContext are appended to tx.queries in
// execution order, with their parameters. Failed statements and SELECTs
// must not be recorded. Logging stays at its reset (disabled) configuration.
func TestLoggedTx_ExecContext(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	const insertSQL = "INSERT INTO test_table VALUES (?, ?, ?)"
	t.Run("records INSERT", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		_, err = tx.ExecContext(ctx, insertSQL, "id1", "name1", 42)
		assertNoError(t, err)
		// Fatal length checks: the assertions below index into the slices.
		if len(tx.queries) != 1 {
			t.Fatalf("expected 1 recorded query, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[0].SQL, "INSERT")
		params := tx.queries[0].Parameters
		if len(params) != 3 {
			t.Fatalf("expected 3 recorded parameters, got %d", len(params))
		}
		assertEqual(t, "id1", params[0])
	})
	t.Run("records UPDATE", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		if _, err := tx.ExecContext(ctx, insertSQL, "id2", "name2", 1); err != nil {
			t.Fatalf("setup insert: %v", err)
		}
		_, err = tx.ExecContext(ctx, "UPDATE test_table SET value = ? WHERE id = ?", 100, "id2")
		assertNoError(t, err)
		if len(tx.queries) != 2 {
			t.Fatalf("expected 2 recorded queries, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[1].SQL, "UPDATE")
	})
	t.Run("records DELETE", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		if _, err := tx.ExecContext(ctx, insertSQL, "id3", "name3", 1); err != nil {
			t.Fatalf("setup insert: %v", err)
		}
		_, err = tx.ExecContext(ctx, "DELETE FROM test_table WHERE id = ?", "id3")
		assertNoError(t, err)
		if len(tx.queries) != 2 {
			t.Fatalf("expected 2 recorded queries, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[1].SQL, "DELETE")
	})
	t.Run("does not record SELECT", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		if _, err := tx.ExecContext(ctx, insertSQL, "id4", "name4", 1); err != nil {
			t.Fatalf("setup insert: %v", err)
		}
		// SELECT should not be recorded
		tx.QueryRowContext(ctx, "SELECT * FROM test_table WHERE id = ?", "id4")
		assertLen(t, 1, len(tx.queries)) // Only the INSERT
	})
	t.Run("does not record failed execution", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		// This will fail (table doesn't exist)
		_, err = tx.ExecContext(ctx, "INSERT INTO nonexistent_table VALUES (?)", "x")
		assertError(t, err)
		assertLen(t, 0, len(tx.queries)) // Failed query not recorded
	})
	t.Run("multiple executions recorded in order", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		steps := []struct {
			sql  string
			args []interface{}
		}{
			{insertSQL, []interface{}{"id1", "name1", 1}},
			{insertSQL, []interface{}{"id2", "name2", 2}},
			{"UPDATE test_table SET value = ? WHERE id = ?", []interface{}{99, "id1"}},
		}
		for _, s := range steps {
			if _, err := tx.ExecContext(ctx, s.sql, s.args...); err != nil {
				t.Fatalf("ExecContext(%q): %v", s.sql, err)
			}
		}
		if len(tx.queries) != 3 {
			t.Fatalf("expected 3 recorded queries, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[0].SQL, "INSERT")
		assertContains(t, tx.queries[1].SQL, "INSERT")
		assertContains(t, tx.queries[2].SQL, "UPDATE")
	})
	t.Run("parameters stored correctly", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		if _, err := tx.ExecContext(ctx, insertSQL, "param_id", "param_name", 123); err != nil {
			t.Fatalf("insert: %v", err)
		}
		if len(tx.queries) != 1 {
			t.Fatalf("expected 1 recorded query, got %d", len(tx.queries))
		}
		params := tx.queries[0].Parameters
		if len(params) != 3 {
			t.Fatalf("expected 3 recorded parameters, got %d", len(params))
		}
		assertEqual(t, "param_id", params[0])
		assertEqual(t, "param_name", params[1])
		assertEqual(t, 123, params[2])
	})
}
// TestLoggedTx_Exec checks that the context-free Exec variant records a
// successful INSERT just like ExecContext.
func TestLoggedTx_Exec(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("INSERT without context", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(context.Background(), db, "test")
		// Fatal: the deferred Rollback and the checks below use tx.
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		_, err = tx.Exec("INSERT INTO test_table VALUES (?, ?, ?)", "id1", "name1", 42)
		assertNoError(t, err)
		if len(tx.queries) != 1 {
			t.Fatalf("expected 1 recorded query, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[0].SQL, "INSERT")
	})
}
// TestLoggedTx_Commit checks what Commit does with the event log:
//   - With logging enabled and at least one recorded mutation, it appends one
//     JSONL event holding every recorded query.
//   - With logging disabled, or when nothing was recorded, it creates no file.
//   - The committed data is persisted.
func TestLoggedTx_Commit(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	const insertSQL = "INSERT INTO test_table VALUES (?, ?, ?)"
	t.Run("writes event to file on commit", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		// Cleanups run LIFO, so this one (registered after TempDir) closes the
		// open log before the temp dir is removed. Windows cannot delete an
		// open file.
		t.Cleanup(resetGlobalState)
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test_tool")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "id1", "name1", 42); err != nil {
			t.Fatalf("insert: %v", err)
		}
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		// Verify event was written
		events, err := readEventsFile(logPath)
		if err != nil {
			t.Fatalf("reading event log: %v", err)
		}
		if len(events) != 1 {
			t.Fatalf("expected 1 event, got %d", len(events))
		}
		ev := events[0]
		assertLen(t, 21, len(ev.ID))
		assertEqual(t, "test_tool", ev.Tool)
		assertLen(t, 1, len(ev.Queries))
		assertTrue(t, ev.Success)
		// Duration may be 0 for very fast transactions
		assertTrue(t, ev.Duration >= 0)
	})
	t.Run("does not write when logging disabled", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: false, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test_tool")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "id2", "name2", 1); err != nil {
			t.Fatalf("insert: %v", err)
		}
		assertNoError(t, tx.Commit())
		// No file should be created
		_, err = os.Stat(logPath)
		assertTrue(t, os.IsNotExist(err), "file should not exist")
	})
	t.Run("does not write when no mutations", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		t.Cleanup(resetGlobalState)
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test_tool")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		// No mutations, just reads
		tx.QueryRowContext(ctx, "SELECT 1")
		assertNoError(t, tx.Commit())
		// No file should be created
		_, err = os.Stat(logPath)
		assertTrue(t, os.IsNotExist(err), "file should not exist")
	})
	t.Run("multiple mutations in single event", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		t.Cleanup(resetGlobalState)
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "multi_test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "m1", "name1", 1); err != nil {
			t.Fatalf("insert m1: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "m2", "name2", 2); err != nil {
			t.Fatalf("insert m2: %v", err)
		}
		if _, err := tx.ExecContext(ctx, "UPDATE test_table SET value = ? WHERE id = ?", 99, "m1"); err != nil {
			t.Fatalf("update m1: %v", err)
		}
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		events, err := readEventsFile(logPath)
		if err != nil {
			t.Fatalf("reading event log: %v", err)
		}
		if len(events) != 1 {
			t.Fatalf("expected 1 event, got %d", len(events))
		}
		assertLen(t, 3, len(events[0].Queries))
	})
	t.Run("data persisted after commit", func(t *testing.T) {
		resetGlobalState()
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "persist_test", "name", 42); err != nil {
			t.Fatalf("insert: %v", err)
		}
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		var count int
		err = db.QueryRow("SELECT COUNT(*) FROM test_table WHERE id = ?", "persist_test").Scan(&count)
		assertNoError(t, err)
		assertEqual(t, 1, count)
	})
	t.Run("event has valid timestamp", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		t.Cleanup(resetGlobalState)
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "ts_test", "name", 1); err != nil {
			t.Fatalf("insert: %v", err)
		}
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		events, err := readEventsFile(logPath)
		if err != nil {
			t.Fatalf("reading event log: %v", err)
		}
		// Guard before indexing: events[0] would panic on an empty log.
		if len(events) == 0 {
			t.Fatalf("expected an event in %s, found none", logPath)
		}
		// Timestamp should be recent (within last 5 seconds)
		elapsed := time.Since(events[0].Timestamp)
		assertTrue(t, elapsed < 5*time.Second, "timestamp should be recent")
	})
}
// TestLoggedTx_Rollback checks that Rollback discards the recorded queries
// (tx.queries becomes nil), writes no event, leaves no data behind, and
// returns nil on success.
func TestLoggedTx_Rollback(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	const insertSQL = "INSERT INTO test_table VALUES (?, ?, ?)"
	t.Run("discards recorded queries", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "id1", "name1", 42); err != nil {
			t.Fatalf("insert: %v", err)
		}
		assertLen(t, 1, len(tx.queries))
		assertNoError(t, tx.Rollback())
		// Queries should be nil after rollback; read them under tx.mu.
		tx.mu.Lock()
		queries := tx.queries
		tx.mu.Unlock()
		assertNil(t, queries)
	})
	t.Run("does not write event to file", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test_tool")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "id1", "name1", 42); err != nil {
			t.Fatalf("insert: %v", err)
		}
		assertNoError(t, tx.Rollback())
		// No file should be created
		_, err = os.Stat(logPath)
		assertTrue(t, os.IsNotExist(err), "file should not exist")
	})
	t.Run("data not persisted", func(t *testing.T) {
		resetGlobalState()
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "rb_test", "name", 42); err != nil {
			t.Fatalf("insert: %v", err)
		}
		assertNoError(t, tx.Rollback())
		var count int
		err = db.QueryRow("SELECT COUNT(*) FROM test_table WHERE id = ?", "rb_test").Scan(&count)
		assertNoError(t, err)
		assertEqual(t, 0, count)
	})
	t.Run("rollback returns nil on success", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, insertSQL, "x", "y", 1); err != nil {
			t.Fatalf("insert: %v", err)
		}
		assertNoError(t, tx.Rollback())
	})
}
// TestLoggedTx_QueryMethods checks that the read methods (QueryRowContext,
// QueryRow, QueryContext, Query) return data from a row committed
// beforehand, and never add entries to tx.queries.
func TestLoggedTx_QueryMethods(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	db := setupTestDB(t)
	defer db.Close()
	// Setup: insert a row
	// Seed the row every subtest reads; abort early if seeding fails, because
	// otherwise each subtest would fail with a misleading "no rows" error.
	seed, err := BeginLoggedTx(ctx, db, "test")
	if err != nil {
		t.Fatalf("BeginLoggedTx (seed): %v", err)
	}
	if _, err := seed.ExecContext(ctx, "INSERT INTO test_table VALUES (?, ?, ?)", "q1", "name1", 42); err != nil {
		seed.Rollback()
		t.Fatalf("seed insert: %v", err)
	}
	if err := seed.Commit(); err != nil {
		t.Fatalf("seed commit: %v", err)
	}
	t.Run("QueryRowContext returns row", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		var name string
		err = tx.QueryRowContext(ctx, "SELECT name FROM test_table WHERE id = ?", "q1").Scan(&name)
		assertNoError(t, err)
		assertEqual(t, "name1", name)
	})
	t.Run("QueryRow returns row", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		var value int
		err = tx.QueryRow("SELECT value FROM test_table WHERE id = ?", "q1").Scan(&value)
		assertNoError(t, err)
		assertEqual(t, 42, value)
	})
	t.Run("QueryContext returns rows", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		rows, err := tx.QueryContext(ctx, "SELECT * FROM test_table")
		// Fatal: on error rows is nil and the deferred Close would panic.
		if err != nil {
			t.Fatalf("QueryContext: %v", err)
		}
		defer rows.Close()
		count := 0
		for rows.Next() {
			count++
		}
		assertGreater(t, int64(count), 0)
	})
	t.Run("Query returns rows", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		rows, err := tx.Query("SELECT * FROM test_table")
		if err != nil {
			t.Fatalf("Query: %v", err)
		}
		defer rows.Close()
		assertTrue(t, rows.Next(), "should have at least one row")
	})
	t.Run("query methods not recorded", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		tx.QueryRowContext(ctx, "SELECT * FROM test_table")
		// Close the result set so it does not stay open until Rollback.
		if rows, err := tx.QueryContext(ctx, "SELECT * FROM test_table"); err == nil {
			rows.Close()
		}
		assertLen(t, 0, len(tx.queries))
	})
}
// TestLoggedTx_Prepare checks that PrepareContext and Prepare return a
// statement wrapper that remembers its SQL text, and that invalid SQL yields
// an error and a nil statement.
func TestLoggedTx_Prepare(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	t.Run("valid prepare", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "INSERT INTO test_table VALUES (?, ?, ?)")
		// Fatal: stmt.sql below would dereference a nil statement.
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		assertNotNil(t, stmt)
		assertEqual(t, "INSERT INTO test_table VALUES (?, ?, ?)", stmt.sql)
	})
	t.Run("prepare without context", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.Prepare("INSERT INTO test_table VALUES (?, ?, ?)")
		if err != nil {
			t.Fatalf("Prepare: %v", err)
		}
		defer stmt.Close()
		assertNotNil(t, stmt)
	})
	t.Run("invalid SQL returns error", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.Prepare("INVALID SQL SYNTAX !!!")
		assertError(t, err)
		assertNil(t, stmt)
	})
}
// TestLoggedStmt_ExecContext checks that each successful execution of a
// prepared mutating statement is recorded on the owning transaction, with
// its parameters. Failed executions and prepared SELECTs must not be
// recorded. Commit must write all recorded statement executions as one event.
func TestLoggedStmt_ExecContext(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	ctx := context.Background()
	const insertSQL = "INSERT INTO test_table VALUES (?, ?, ?)"
	t.Run("INSERT with prepared stmt", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		_, err = stmt.ExecContext(ctx, "ps1", "name1", 42)
		assertNoError(t, err)
		if len(tx.queries) != 1 {
			t.Fatalf("expected 1 recorded query, got %d", len(tx.queries))
		}
		assertContains(t, tx.queries[0].SQL, "INSERT")
	})
	t.Run("multiple executions recorded separately", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		inputs := []struct {
			id, name string
			value    int
		}{{"ps1", "name1", 1}, {"ps2", "name2", 2}, {"ps3", "name3", 3}}
		for _, in := range inputs {
			if _, err := stmt.ExecContext(ctx, in.id, in.name, in.value); err != nil {
				t.Fatalf("ExecContext(%s): %v", in.id, err)
			}
		}
		assertLen(t, 3, len(tx.queries))
	})
	t.Run("parameters captured correctly", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		if _, err := stmt.ExecContext(ctx, "captured_id", "captured_name", 999); err != nil {
			t.Fatalf("ExecContext: %v", err)
		}
		// Fatal length checks: the assertions below index into the slices.
		if len(tx.queries) != 1 {
			t.Fatalf("expected 1 recorded query, got %d", len(tx.queries))
		}
		params := tx.queries[0].Parameters
		if len(params) != 3 {
			t.Fatalf("expected 3 recorded parameters, got %d", len(params))
		}
		assertEqual(t, "captured_id", params[0])
		assertEqual(t, "captured_name", params[1])
		assertEqual(t, 999, params[2])
	})
	t.Run("SELECT prepared stmt not recorded", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		// First insert some data
		seed, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx (seed): %v", err)
		}
		if _, err := seed.ExecContext(ctx, insertSQL, "sel_test", "name", 1); err != nil {
			t.Fatalf("seed insert: %v", err)
		}
		if err := seed.Commit(); err != nil {
			t.Fatalf("seed commit: %v", err)
		}
		// Now test SELECT prepared statement
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "SELECT name FROM test_table WHERE id = ?")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		var name string
		err = stmt.QueryRowContext(ctx, "sel_test").Scan(&name)
		assertNoError(t, err)
		assertEqual(t, "name", name)
		assertLen(t, 0, len(tx.queries))
	})
	t.Run("failed execution not recorded", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		// Insert one row
		if _, err := tx.ExecContext(ctx, insertSQL, "dup_id", "name", 1); err != nil {
			t.Fatalf("setup insert: %v", err)
		}
		// Try to insert duplicate (will fail due to primary key)
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		_, err = stmt.ExecContext(ctx, "dup_id", "name2", 2)
		assertError(t, err)
		// Only first INSERT should be recorded
		assertLen(t, 1, len(tx.queries))
	})
	t.Run("commit writes all prepared stmt queries", func(t *testing.T) {
		resetGlobalState()
		logPath := filepath.Join(t.TempDir(), "events.jsonl")
		// Cleanups run LIFO: close the open log before TempDir removes the
		// directory (required on Windows). This also turns logging back off
		// for the subtests that follow.
		t.Cleanup(resetGlobalState)
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "prep_commit_test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		if _, err := stmt.ExecContext(ctx, "pc1", "name1", 1); err != nil {
			t.Fatalf("exec pc1: %v", err)
		}
		if _, err := stmt.ExecContext(ctx, "pc2", "name2", 2); err != nil {
			t.Fatalf("exec pc2: %v", err)
		}
		stmt.Close()
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		events, err := readEventsFile(logPath)
		if err != nil {
			t.Fatalf("reading event log: %v", err)
		}
		if len(events) != 1 {
			t.Fatalf("expected 1 event, got %d", len(events))
		}
		assertLen(t, 2, len(events[0].Queries))
	})
	t.Run("Exec without context", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, insertSQL)
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		_, err = stmt.Exec("exec_id", "name", 42)
		assertNoError(t, err)
		assertLen(t, 1, len(tx.queries))
	})
}
// TestLoggedStmt_QueryMethods verifies that the read methods of a statement
// prepared via LoggedTx.PrepareContext (QueryRowContext, QueryRow,
// QueryContext, Query) return the row seeded at the start of the test.
//
// Setup and prepare failures abort with t.Fatalf. Continuing with a nil
// transaction, statement or rows value would only produce a nil-pointer
// panic that hides the real cause.
func TestLoggedStmt_QueryMethods(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	db := setupTestDB(t)
	defer db.Close()
	ctx := context.Background()

	// Setup: insert the single row every subtest reads back.
	setupTx, err := BeginLoggedTx(ctx, db, "test")
	if err != nil {
		t.Fatalf("BeginLoggedTx (setup): %v", err)
	}
	if _, err := setupTx.ExecContext(ctx, "INSERT INTO test_table VALUES (?, ?, ?)", "qry1", "name1", 42); err != nil {
		setupTx.Rollback()
		t.Fatalf("seeding test_table: %v", err)
	}
	if err := setupTx.Commit(); err != nil {
		t.Fatalf("committing seed row: %v", err)
	}

	t.Run("QueryRowContext returns row", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "SELECT name FROM test_table WHERE id = ?")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		var name string
		err = stmt.QueryRowContext(ctx, "qry1").Scan(&name)
		assertNoError(t, err)
		assertEqual(t, "name1", name)
	})
	t.Run("QueryRow returns row", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "SELECT value FROM test_table WHERE id = ?")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		var value int
		err = stmt.QueryRow("qry1").Scan(&value)
		assertNoError(t, err)
		assertEqual(t, 42, value)
	})
	t.Run("QueryContext returns rows", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "SELECT * FROM test_table WHERE id = ?")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		rows, err := stmt.QueryContext(ctx, "qry1")
		if err != nil {
			// rows is nil here; deferring rows.Close() would panic.
			t.Fatalf("QueryContext: %v", err)
		}
		defer rows.Close()
		assertTrue(t, rows.Next(), "should have one row")
		// id is the primary key and only one row was seeded.
		assertTrue(t, !rows.Next(), "should have exactly one row")
		// Next returning false can mean end-of-rows or an iteration error.
		assertNoError(t, rows.Err())
	})
	t.Run("Query returns rows", func(t *testing.T) {
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "SELECT * FROM test_table")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		defer stmt.Close()
		rows, err := stmt.Query()
		if err != nil {
			t.Fatalf("Query: %v", err)
		}
		defer rows.Close()
		assertTrue(t, rows.Next(), "should have at least one row")
	})
}
// TestLoggedStmt_Close verifies that closing a statement prepared via
// LoggedTx.PrepareContext returns no error.
//
// Setup failures abort with t.Fatalf, so a failed prepare is reported
// directly instead of surfacing as a Close call on a nil statement.
func TestLoggedStmt_Close(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("close returns nil on success", func(t *testing.T) {
		db := setupTestDB(t)
		defer db.Close()
		ctx := context.Background()
		tx, err := BeginLoggedTx(ctx, db, "test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		defer tx.Rollback()
		stmt, err := tx.PrepareContext(ctx, "INSERT INTO test_table VALUES (?, ?, ?)")
		if err != nil {
			t.Fatalf("PrepareContext: %v", err)
		}
		err = stmt.Close()
		assertNoError(t, err)
	})
}
// TestEnsureEventLogFile exercises ensureEventLogFile. The subtests expect it
// to:
//   - open the file at the configured path (creating the file and any missing
//     parent directories);
//   - store the handle in the package-level eventLogFile;
//   - preserve existing content, with new writes appended after it;
//   - reuse the handle on repeated calls.
//
// Each subtest registers resetGlobalState with t.Cleanup *after* t.TempDir.
// Cleanups run LIFO, so the log file is closed before the temporary directory
// is removed. Removing a directory that holds an open file fails on Windows.
func TestEnsureEventLogFile(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("creates file if doesn't exist", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		t.Cleanup(resetGlobalState)
		logPath := filepath.Join(tmpDir, "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		err := ensureEventLogFile()
		assertNoError(t, err)
		assertNotNil(t, eventLogFile)
		// File should exist
		_, err = os.Stat(logPath)
		assertNoError(t, err)
	})
	t.Run("appends to existing file", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		t.Cleanup(resetGlobalState)
		logPath := filepath.Join(tmpDir, "events.jsonl")
		// Create file with content
		if err := os.WriteFile(logPath, []byte("existing content\n"), 0644); err != nil {
			t.Fatalf("seeding %s: %v", logPath, err)
		}
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		if err := ensureEventLogFile(); err != nil {
			t.Fatalf("ensureEventLogFile: %v", err)
		}
		// Write through the opened handle. In append mode the new line lands
		// after the existing content. A handle opened at offset 0 without
		// append would overwrite the existing content instead.
		if _, err := eventLogFile.Write([]byte("new content\n")); err != nil {
			t.Fatalf("writing through eventLogFile: %v", err)
		}
		data, err := os.ReadFile(logPath)
		if err != nil {
			t.Fatalf("reading %s: %v", logPath, err)
		}
		assertEqual(t, "existing content\nnew content\n", string(data))
	})
	t.Run("creates directory if doesn't exist", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		t.Cleanup(resetGlobalState)
		logPath := filepath.Join(tmpDir, "subdir", "deep", "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		err := ensureEventLogFile()
		assertNoError(t, err)
		// Directory should exist
		dir := filepath.Dir(logPath)
		_, err = os.Stat(dir)
		assertNoError(t, err)
	})
	t.Run("returns nil if file already open", func(t *testing.T) {
		resetGlobalState()
		tmpDir := t.TempDir()
		t.Cleanup(resetGlobalState)
		logPath := filepath.Join(tmpDir, "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		if err := ensureEventLogFile(); err != nil {
			t.Fatalf("first ensureEventLogFile: %v", err)
		}
		firstFile := eventLogFile
		err := ensureEventLogFile()
		assertNoError(t, err)
		// Should reuse same file handle
		assertEqual(t, firstFile, eventLogFile)
	})
}
// TestTransactionEventJSON checks the JSON encoding of TransactionEvent: field
// names (id, tool, success, duration_ms), the timestamp format, and the ID
// length produced when a real LoggedTx commits.
//
// Decode failures, missing keys and missing log events abort with t.Fatalf
// rather than panicking on an unchecked type assertion or an out-of-range
// index.
func TestTransactionEventJSON(t *testing.T) {
	resetGlobalState()
	defer resetGlobalState()
	t.Run("complete event serializes correctly", func(t *testing.T) {
		event := TransactionEvent{
			ID:        "test-id-12345",
			Timestamp: time.Date(2026, 2, 18, 14, 30, 0, 0, time.UTC),
			Tool:      "test_tool",
			Queries: []QueryRecord{
				{SQL: "INSERT INTO test VALUES (?)", Parameters: []interface{}{"a"}},
				{SQL: "UPDATE test SET x = ?", Parameters: []interface{}{1}},
			},
			Success:  true,
			Duration: 42,
		}
		data, err := json.Marshal(event)
		if err != nil {
			t.Fatalf("marshal: %v", err)
		}
		var result map[string]interface{}
		if err := json.Unmarshal(data, &result); err != nil {
			t.Fatalf("unmarshal %s: %v", data, err)
		}
		assertEqual(t, "test-id-12345", result["id"])
		assertEqual(t, "test_tool", result["tool"])
		assertEqual(t, true, result["success"])
		// encoding/json decodes every JSON number into float64.
		assertEqual(t, 42.0, result["duration_ms"])
	})
	t.Run("timestamp in RFC3339Nano format", func(t *testing.T) {
		event := TransactionEvent{
			ID:        "ts-test",
			Timestamp: time.Date(2026, 2, 18, 14, 30, 0, 123456789, time.UTC),
			Success:   true,
		}
		data, err := json.Marshal(event)
		if err != nil {
			t.Fatalf("marshal: %v", err)
		}
		var result map[string]interface{}
		if err := json.Unmarshal(data, &result); err != nil {
			t.Fatalf("unmarshal %s: %v", data, err)
		}
		ts, ok := result["timestamp"].(string)
		if !ok {
			t.Fatalf("timestamp: expected JSON string, got %T (%v)", result["timestamp"], result["timestamp"])
		}
		assertContains(t, ts, "2026-02-18T14:30:00.123456789Z")
	})
	t.Run("duration positive", func(t *testing.T) {
		event := TransactionEvent{
			ID:        "dur-test",
			Timestamp: time.Now(),
			Success:   true,
			Duration:  123,
		}
		data, err := json.Marshal(event)
		if err != nil {
			t.Fatalf("marshal: %v", err)
		}
		var result map[string]interface{}
		if err := json.Unmarshal(data, &result); err != nil {
			t.Fatalf("unmarshal %s: %v", data, err)
		}
		durationMS, ok := result["duration_ms"].(float64)
		if !ok {
			t.Fatalf("duration_ms: expected JSON number, got %T (%v)", result["duration_ms"], result["duration_ms"])
		}
		assertGreater(t, int64(durationMS), 0)
		assertEqual(t, 123.0, durationMS)
	})
	t.Run("ID is 21 characters in real usage", func(t *testing.T) {
		// Verify by creating an actual event
		resetGlobalState()
		tmpDir := t.TempDir()
		// Registered after t.TempDir so it runs first (LIFO). This closes the
		// log file before the temporary directory is removed.
		t.Cleanup(resetGlobalState)
		logPath := filepath.Join(tmpDir, "events.jsonl")
		SetEventLogConfig(EventLogConfig{Enabled: true, Path: logPath})
		db := setupTestDB(t)
		defer db.Close()
		ctx := context.Background()
		tx, err := BeginLoggedTx(ctx, db, "id_test")
		if err != nil {
			t.Fatalf("BeginLoggedTx: %v", err)
		}
		if _, err := tx.ExecContext(ctx, "INSERT INTO test_table VALUES (?, ?, ?)", "id_test", "name", 1); err != nil {
			tx.Rollback()
			t.Fatalf("ExecContext: %v", err)
		}
		if err := tx.Commit(); err != nil {
			t.Fatalf("Commit: %v", err)
		}
		events, err := readEventsFile(logPath)
		if err != nil {
			t.Fatalf("reading %s: %v", logPath, err)
		}
		if len(events) != 1 {
			t.Fatalf("expected 1 logged event, got %d", len(events))
		}
		assertLen(t, 21, len(events[0].ID))
	})
}