tx_logger.go
package db
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"time"

	gonanoid "github.com/matoous/go-nanoid/v2"
)
// mutator is the write-capable database surface that db/ code programs
// against, and which *LoggedTx is required to implement. It is declared here,
// on the consumer side, and contains only the context-aware methods. That keeps
// it a strict subset of the method sets of both *sql.DB and *sql.Tx, so any of
// them can be used where a mutator is expected.
type mutator interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// LoggedTx satisfies the local mutator interface.
var _ mutator = (*LoggedTx)(nil)

// LoggedTx wraps *sql.Tx and records every successful mutating statement
// executed through it. That covers Exec/ExecContext and statements run via
// LoggedStmt values it prepared. On Commit the recorded statements are written
// to the event log; Rollback discards them.
type LoggedTx struct {
	tx        *sql.Tx
	queries   []QueryRecord // mutations recorded so far; guarded by mu
	mu        sync.Mutex
	toolName  string    // optional name of the tool that began the transaction
	startTime time.Time // set by BeginLoggedTx; used for the event duration
}
// QueryRecord is one executed SQL statement together with the arguments that
// were bound to it, as stored in a TransactionEvent.
type QueryRecord struct {
	// SQL is the statement text exactly as it was passed for execution.
	SQL string `json:"sql"`

	// Parameters holds the bound arguments in call order.
	Parameters []any `json:"parameters"`
}
// TransactionEvent is a single entry in the event log. It describes one
// committed transaction:
//   - ID is a random identifier for the entry.
//   - Timestamp is when the entry was produced.
//   - Tool is the optional initiating tool name, omitted from JSON when empty.
//   - Queries lists the recorded mutations in execution order.
//   - Success reports whether the transaction committed.
//   - Duration is the elapsed time in milliseconds (JSON key "duration_ms").
type TransactionEvent struct {
	ID        string        `json:"id"`
	Timestamp time.Time     `json:"timestamp"`
	Tool      string        `json:"tool,omitempty"`
	Queries   []QueryRecord `json:"queries"`
	Success   bool          `json:"success"`
	Duration  int64         `json:"duration_ms"`
}
// LoggedStmt is a prepared statement created by a LoggedTx. Successful
// mutating executions are appended to the owning transaction's query log.
type LoggedStmt struct {
	stmt *sql.Stmt // the prepared statement on the wrapped transaction
	tx   *LoggedTx // transaction whose query log receives recorded mutations
	sql  string    // statement text; recorded and checked by isMutation
}
// EventLogConfig controls the process-wide transaction event log. The zero
// value disables logging.
type EventLogConfig struct {
	Enabled bool   // append committed transactions to the log when true
	Path    string // file the JSON events are appended to
}
// Process-wide event-log state.
var (
	// eventLogMu serializes access to the event-log state declared below;
	// code that reads or writes those variables is expected to hold it.
	eventLogMu sync.Mutex

	eventLogConfig EventLogConfig // current configuration
	eventLogFile   *os.File       // open log file, or nil until first needed
	eventLogEnc    *json.Encoder  // encoder writing to eventLogFile
)
// SetEventLogConfig replaces the global event-log configuration with cfg.
//
// If a log file is currently open and cfg names a different path, that file
// is closed (any close error is ignored) and the next write opens cfg.Path.
// When the path is unchanged, an already-open file is kept, even if cfg
// disables logging.
func SetEventLogConfig(cfg EventLogConfig) {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	pathChanged := eventLogConfig.Path != cfg.Path
	if pathChanged && eventLogFile != nil {
		// Best effort: a failed close must not block reconfiguration.
		_ = eventLogFile.Close()
		eventLogFile, eventLogEnc = nil, nil
	}
	eventLogConfig = cfg
}
// BeginLoggedTx begins a transaction on db with default options and wraps it
// in a LoggedTx that records successful mutations for the event log.
//
// toolName may be empty. When set, it is reported as the Tool of the event
// written on Commit.
func BeginLoggedTx(ctx context.Context, db *sql.DB, toolName string) (*LoggedTx, error) {
	sqlTx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	logged := &LoggedTx{
		tx:        sqlTx,
		toolName:  toolName,
		queries:   []QueryRecord{},
		startTime: time.Now(),
	}
	return logged, nil
}
// ExecContext executes query on the wrapped transaction. When execution
// succeeds and isMutation classifies query as a mutation, the SQL text and
// its arguments are recorded so that Commit can write them to the event log.
//
// The argument slice is copied before it is stored. A caller that passes a
// reusable buffer with buf... (for example, a batch-insert loop) would
// otherwise have every recorded query alias the same backing array, so all of
// them would report the final iteration's values. Only the slice is copied;
// pointer arguments are still dereferenced only when the event is serialized.
func (l *LoggedTx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	result, err := l.tx.ExecContext(ctx, query, args...)
	if err == nil && isMutation(query) {
		params := append([]any(nil), args...)
		l.mu.Lock()
		l.queries = append(l.queries, QueryRecord{SQL: query, Parameters: params})
		l.mu.Unlock()
	}
	return result, err
}
// Exec is ExecContext with context.Background(). Successful mutations are
// recorded in exactly the same way.
func (l *LoggedTx) Exec(query string, args ...any) (sql.Result, error) {
	ctx := context.Background()
	return l.ExecContext(ctx, query, args...)
}
// QueryRowContext runs a single-row query on the wrapped transaction.
// Reads are passed straight through and never recorded.
func (l *LoggedTx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	underlying := l.tx
	return underlying.QueryRowContext(ctx, query, args...)
}
// QueryRow is QueryRowContext with context.Background(). It is a read and is
// not recorded.
func (l *LoggedTx) QueryRow(query string, args ...any) *sql.Row {
	// sql.Tx.QueryRow is itself QueryRowContext with a background context.
	return l.QueryRowContext(context.Background(), query, args...)
}
// QueryContext runs a multi-row query on the wrapped transaction. Reads are
// passed straight through and never recorded.
func (l *LoggedTx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	underlying := l.tx
	return underlying.QueryContext(ctx, query, args...)
}
// Query is QueryContext with context.Background(). It is a read and is not
// recorded.
func (l *LoggedTx) Query(query string, args ...any) (*sql.Rows, error) {
	// sql.Tx.Query is itself QueryContext with a background context.
	return l.QueryContext(context.Background(), query, args...)
}
// UnderlyingTx returns the wrapped *sql.Tx.
//
// Statements executed directly on the returned transaction bypass LoggedTx
// and are never recorded in the event log. Pass the LoggedTx itself (it
// satisfies the local mutator interface) or use its methods instead.
//
// Deprecated: UnderlyingTx bypasses the LoggedTx audit trail and will be
// removed in a future version. Use the LoggedTx methods directly.
func (l *LoggedTx) UnderlyingTx() *sql.Tx {
	return l.tx
}
// PrepareContext prepares query on the wrapped transaction. It returns a
// LoggedStmt whose successful mutating executions are recorded on l.
func (l *LoggedTx) PrepareContext(ctx context.Context, query string) (*LoggedStmt, error) {
	prepared, err := l.tx.PrepareContext(ctx, query)
	if err != nil {
		return nil, err
	}
	ls := &LoggedStmt{sql: query, tx: l, stmt: prepared}
	return ls, nil
}
// Prepare is PrepareContext with context.Background(). Mutations executed
// through the returned statement are recorded on l.
func (l *LoggedTx) Prepare(query string) (*LoggedStmt, error) {
	bg := context.Background()
	return l.PrepareContext(bg, query)
}
// Rollback forgets every recorded mutation, so nothing from this transaction
// reaches the event log, and then rolls back the wrapped transaction. The
// returned error comes from sql.Tx.Rollback.
func (l *LoggedTx) Rollback() error {
	func() {
		l.mu.Lock()
		defer l.mu.Unlock()
		l.queries = nil
	}()
	return l.tx.Rollback()
}
// Commit commits the wrapped transaction. If the commit succeeds and at least
// one mutation was recorded, it hands the recorded mutations to writeEvent.
// Event-log failures are reported by writeEvent and never turn a successful
// commit into an error.
//
// The event log's Enabled flag is not read here. writeEvent checks it while
// holding eventLogMu. Reading the global configuration without that lock
// would be a data race with SetEventLogConfig and CloseEventLog.
func (l *LoggedTx) Commit() error {
	if err := l.tx.Commit(); err != nil {
		return err
	}

	l.mu.Lock()
	queries := l.queries
	l.mu.Unlock()

	if len(queries) > 0 {
		l.writeEvent(queries)
	}
	return nil
}
// writeEvent appends one JSON-encoded TransactionEvent describing this
// transaction and its recorded queries to the event log.
//
// The whole function runs under eventLogMu. It re-checks
// eventLogConfig.Enabled and opens the log file on first use. Every failure
// is printed to stderr and then ignored, because this is called after the
// transaction has already committed.
func (l *LoggedTx) writeEvent(queries []QueryRecord) {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	if !eventLogConfig.Enabled {
		return
	}

	if openErr := ensureEventLogFile(); openErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to open event log: %v\n", openErr)
		return
	}

	eventID, idErr := gonanoid.New(21)
	if idErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to generate event ID: %v\n", idErr)
		return
	}

	var event TransactionEvent
	event.ID = eventID
	event.Timestamp = time.Now()
	event.Tool = l.toolName
	event.Queries = queries
	event.Success = true
	event.Duration = time.Since(l.startTime).Milliseconds()

	if encErr := eventLogEnc.Encode(event); encErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: failed to write event log: %v\n", encErr)
	}
}
// LoggedStmt methods

// ExecContext executes the prepared statement. If execution succeeds and
// isMutation classifies the statement text as a mutation, it records the SQL
// and arguments on the owning LoggedTx so that Commit can write them to the
// event log.
//
// args is copied before it is stored. A caller reusing one argument buffer
// across executions (a typical prepared-statement batch loop) would otherwise
// retroactively change every earlier record. Only the slice is copied;
// pointer arguments are still dereferenced only when the event is serialized.
func (s *LoggedStmt) ExecContext(ctx context.Context, args ...any) (sql.Result, error) {
	result, err := s.stmt.ExecContext(ctx, args...)
	if err == nil && isMutation(s.sql) {
		params := append([]any(nil), args...)
		s.tx.mu.Lock()
		s.tx.queries = append(s.tx.queries, QueryRecord{SQL: s.sql, Parameters: params})
		s.tx.mu.Unlock()
	}
	return result, err
}
// Exec is ExecContext with context.Background(). Successful mutations are
// recorded on the owning transaction in the same way.
func (s *LoggedStmt) Exec(args ...any) (sql.Result, error) {
	ctx := context.Background()
	return s.ExecContext(ctx, args...)
}
// QueryRowContext runs the prepared statement as a single-row query. Reads
// are not recorded.
func (s *LoggedStmt) QueryRowContext(ctx context.Context, args ...any) *sql.Row {
	prepared := s.stmt
	return prepared.QueryRowContext(ctx, args...)
}
// QueryRow is QueryRowContext with context.Background(). Reads are not
// recorded.
func (s *LoggedStmt) QueryRow(args ...any) *sql.Row {
	// sql.Stmt.QueryRow is itself QueryRowContext with a background context.
	return s.QueryRowContext(context.Background(), args...)
}
// QueryContext runs the prepared statement as a multi-row query. Reads are
// not recorded.
func (s *LoggedStmt) QueryContext(ctx context.Context, args ...any) (*sql.Rows, error) {
	prepared := s.stmt
	return prepared.QueryContext(ctx, args...)
}
// Query is QueryContext with context.Background(). Reads are not recorded.
func (s *LoggedStmt) Query(args ...any) (*sql.Rows, error) {
	// sql.Stmt.Query is itself QueryContext with a background context.
	return s.QueryContext(context.Background(), args...)
}
// Close releases the underlying prepared statement and returns its error.
func (s *LoggedStmt) Close() error {
	prepared := s.stmt
	return prepared.Close()
}
// isMutation reports whether sqlStr is a data-modifying statement that should
// be recorded in the event log. The check is case-insensitive. Any leading
// whitespace and leading "--" or "/* */" comments are skipped first. A
// statement counts as a mutation when its first keyword is INSERT, UPDATE,
// DELETE, or REPLACE (the SQLite/MySQL "REPLACE INTO" form).
//
// A statement whose first keyword is WITH (a CTE) counts as a mutation when
// INSERT, UPDATE, or DELETE appears anywhere in it as a whole word. This is
// deliberately conservative: a keyword inside a string literal or quoted
// identifier still counts. Identifiers that merely contain those words, such
// as updated_at or is_deleted, do not.
func isMutation(sqlStr string) bool {
	// Split on anything that cannot be part of an unquoted SQL identifier.
	// Non-ASCII runes are treated as identifier characters so they never
	// split a word.
	notIdent := func(r rune) bool {
		switch {
		case r == '_', r == '$', r > 127:
			return false
		case r >= 'A' && r <= 'Z', r >= 'a' && r <= 'z', r >= '0' && r <= '9':
			return false
		}
		return true
	}

	stmt := strings.TrimSpace(sqlStr)
	// Skip leading comments so "-- note\nINSERT ..." is still recognized.
	for {
		if strings.HasPrefix(stmt, "--") {
			_, rest, found := strings.Cut(stmt, "\n")
			if !found {
				return false // the whole string is a comment
			}
			stmt = strings.TrimSpace(rest)
			continue
		}
		if strings.HasPrefix(stmt, "/*") {
			_, rest, found := strings.Cut(stmt[2:], "*/")
			if !found {
				return false // unterminated comment; nothing executable follows
			}
			stmt = strings.TrimSpace(rest)
			continue
		}
		break
	}

	words := strings.FieldsFunc(strings.ToUpper(stmt), notIdent)
	if len(words) == 0 {
		return false
	}
	switch words[0] {
	case "INSERT", "UPDATE", "DELETE", "REPLACE":
		return true
	case "WITH":
		for _, w := range words[1:] {
			switch w {
			case "INSERT", "UPDATE", "DELETE":
				return true
			}
		}
	}
	return false
}
// ensureEventLogFile opens the event log on first use. It creates any missing
// parent directories of eventLogConfig.Path, opens the file for appending
// (creating it if needed), and installs a JSON encoder with HTML escaping
// turned off. It does nothing when a file is already open.
//
// It reads and writes the package-level event-log state, so callers are
// expected to hold eventLogMu (writeEvent does).
func ensureEventLogFile() error {
	if eventLogFile != nil {
		return nil
	}

	path := eventLogConfig.Path
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return fmt.Errorf("failed to create event log directory: %w", err)
	}

	flags := os.O_APPEND | os.O_CREATE | os.O_WRONLY
	f, err := os.OpenFile(path, flags, 0644)
	if err != nil {
		return fmt.Errorf("failed to open event log file: %w", err)
	}

	enc := json.NewEncoder(f)
	enc.SetEscapeHTML(false)
	eventLogFile, eventLogEnc = f, enc
	return nil
}
// CloseEventLog turns event logging off and closes the log file if one is
// open, returning the error from closing it. Logging stays off until
// SetEventLogConfig installs a configuration with Enabled set.
func CloseEventLog() error {
	eventLogMu.Lock()
	defer eventLogMu.Unlock()

	// Disable first so no later writeEvent tries to reopen the file.
	eventLogConfig.Enabled = false

	f := eventLogFile
	if f == nil {
		return nil
	}
	eventLogFile, eventLogEnc = nil, nil
	return f.Close()
}
// MarshalJSON implements json.Marshaler for QueryRecord. The SQL text is
// emitted unchanged. Each parameter is first converted by marshalParam
// (nil/pointer handling, time formatting, and so on). "parameters" is always
// a JSON array, never null.
func (q QueryRecord) MarshalJSON() ([]byte, error) {
	// This local type has no MarshalJSON method. Marshaling a QueryRecord
	// here instead would recurse back into this method.
	type plainQueryRecord struct {
		SQL        string `json:"sql"`
		Parameters []any  `json:"parameters"`
	}

	params := make([]any, 0, len(q.Parameters))
	for _, p := range q.Parameters {
		params = append(params, marshalParam(p))
	}
	return json.Marshal(plainQueryRecord{SQL: q.SQL, Parameters: params})
}
// marshalParam converts one SQL argument into a value that encoding/json can
// represent faithfully in the event log:
//   - nil and nil pointers become JSON null.
//   - driver.Valuer implementations (sql.NullString, sql.NullInt64,
//     sql.NullTime, custom column types, ...) are replaced by the value they
//     would send to the database. An invalid Null* therefore becomes null
//     rather than a formatted struct such as "{ false}". If Value fails, the
//     generic handling below is used instead.
//   - Other non-nil pointers are dereferenced and converted recursively.
//   - time.Time becomes an RFC 3339 string with nanoseconds.
//   - Strings, booleans, and the built-in numeric types are returned
//     unchanged. []byte is returned unchanged too (encoding/json
//     base64-encodes it).
//   - Anything else is formatted with %v.
func marshalParam(param any) any {
	if param == nil {
		return nil
	}

	rv := reflect.ValueOf(param)
	isPtr := rv.Kind() == reflect.Pointer
	if isPtr && rv.IsNil() {
		// Checked before the Valuer case: calling a value-receiver Value
		// method through a nil pointer would panic.
		return nil
	}

	if valuer, ok := param.(driver.Valuer); ok {
		if val, err := valuer.Value(); err == nil {
			// A well-behaved Valuer returns a plain driver.Value. Refuse to
			// follow one that returns another Valuer, so a buggy
			// implementation cannot recurse without bound.
			if _, nested := val.(driver.Valuer); !nested {
				return marshalParam(val)
			}
		}
	}

	if isPtr {
		return marshalParam(rv.Elem().Interface())
	}

	switch v := param.(type) {
	case time.Time:
		return v.Format(time.RFC3339Nano)
	case string:
		return v
	case int, int8, int16, int32, int64,
		uint, uint8, uint16, uint32, uint64,
		float32, float64, bool:
		return v
	case []byte:
		return v
	default:
		return fmt.Sprintf("%v", v)
	}
}