Tx Context
2023-04-19
import (
"context"
"database/sql"
"database/sql/driver"
"github.com/jmoiron/sqlx"
)
// Executor declares the query methods shared by *sqlx.DB and *sqlx.Tx (or at
// least the subset this package wishes to use), so the same statements can be
// run either directly against the database or inside a transaction.
type Executor interface {
	sqlx.QueryerContext
	sqlx.ExecerContext
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
}

// Compile-time checks that every type this package treats as an Executor
// really implements it. GetContextTx and txOrNot recover the transaction with
// a runtime .(Executor) assertion, so if *sqlx.Tx ever stopped satisfying the
// interface they would silently fall back to the database; these assertions
// turn that drift into a build error instead.
var (
	_ Executor = (*sqlx.DB)(nil)
	_ Executor = (*sqlx.Tx)(nil)
	_ Executor = CtxAwareExecutor{}
)
// ContextTx is a context.Context bundled with the Executor it was created
// around (the *sqlx.Tx from NewContextTx, or the value GetContextTx found in a
// context).
type ContextTx struct {
	context.Context
	executor Executor
}

// Tx exposes the wrapped executor as a driver.Tx so the caller can Commit or
// Rollback it. The result is nil whenever the executor does not implement
// driver.Tx, including a zero ContextTx whose executor was never set.
func (t ContextTx) Tx() driver.Tx {
	if tx, isTx := t.executor.(driver.Tx); isTx {
		return tx
	}
	return nil
}
// NewContextTx begins a transaction on db and hands it back in two forms:
//   - a *ContextTx whose Context carries the transaction under CtxTxKey, so
//     GetContextTx and CtxAwareExecutor will route queries through it;
//   - the same transaction as a driver.Tx, for the caller to Commit or Rollback.
//
// NewContextTx never finishes the transaction itself. If BeginTxx fails, its
// error is returned unchanged together with two nil results.
func NewContextTx(ctx context.Context, db *sqlx.DB) (*ContextTx, driver.Tx, error) {
	// Zero-valued options: LevelDefault isolation, not read-only.
	opts := &sql.TxOptions{}
	tx, err := db.BeginTxx(ctx, opts)
	if err != nil {
		return nil, nil, err
	}
	txCtx := context.WithValue(ctx, CtxTxKey{}, tx)
	wrapped := &ContextTx{Context: txCtx, executor: tx}
	return wrapped, tx, nil
}
// CtxTxKey is the context key under which NewContextTx stores the active
// transaction. A dedicated struct type cannot collide with keys defined by
// other packages.
type CtxTxKey struct{}

// WithoutTx returns a child of ctx in which no transaction is visible, even if
// an ancestor context carries one. Queries issued through CtxAwareExecutor with
// the returned context therefore go straight to the database.
func WithoutTx(ctx context.Context) context.Context {
	// An explicit nil under CtxTxKey shadows any ancestor's value: nil is not
	// an Executor, so lookups treat the context as transaction-free.
	var noTx interface{}
	return context.WithValue(ctx, CtxTxKey{}, noTx)
}
// GetContextTx returns a *ContextTx wrapping the Executor stored in ctx under
// CtxTxKey, or nil when ctx holds no such value. That covers an absent key, the
// explicit nil set by WithoutTx, and any value that does not implement Executor.
//
// The returned ContextTx uses ctx itself as its Context. ctx already carries
// the executor under CtxTxKey, so deriving another context.WithValue layer with
// the same key/value would only allocate and lengthen later Value lookups.
func GetContextTx(ctx context.Context) *ContextTx {
	tx, ok := ctx.Value(CtxTxKey{}).(Executor)
	if !ok {
		// A nil value or a non-Executor value both fail the assertion.
		return nil
	}
	return &ContextTx{
		Context:  ctx,
		executor: tx,
	}
}
// CtxAwareExecutor sends each query either to the transaction carried by the
// call's context (stored under CtxTxKey) or, when there is none, to db.
type CtxAwareExecutor struct {
	db *sqlx.DB // fallback target when the context holds no transaction
}
// txOrNot picks the Executor for a call made with ctx. It returns the
// transaction stored under CtxTxKey when there is one, otherwise c.db.
//
// It reads ctx directly rather than going through GetContextTx, which
// allocates a wrapper on every call only for this method to read one field.
// The selection rule is unchanged: a stored value is used only if it
// implements Executor, so a nil value (as set by WithoutTx) or any foreign
// value falls back to the database.
func (c CtxAwareExecutor) txOrNot(ctx context.Context) Executor {
	if tx, ok := ctx.Value(CtxTxKey{}).(Executor); ok {
		return tx
	}
	return c.db
}
// QueryContext runs query with args on the executor selected for ctx (the
// context's transaction if present, else the database) and returns *sql.Rows.
func (c CtxAwareExecutor) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	target := c.txOrNot(ctx)
	return target.QueryContext(ctx, query, args...)
}
// QueryxContext is the *sqlx.Rows flavour of QueryContext. It runs query on
// the executor chosen for ctx (transaction if present, else the database).
func (c CtxAwareExecutor) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
	target := c.txOrNot(ctx)
	return target.QueryxContext(ctx, query, args...)
}
// QueryRowxContext runs a single-row query on the executor chosen for ctx
// (transaction if present, else the database) and returns its *sqlx.Row.
func (c CtxAwareExecutor) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
	target := c.txOrNot(ctx)
	return target.QueryRowxContext(ctx, query, args...)
}
// ExecContext executes a statement that returns no rows on the executor
// chosen for ctx (transaction if present, else the database).
func (c CtxAwareExecutor) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	target := c.txOrNot(ctx)
	return target.ExecContext(ctx, query, args...)
}
// SelectContext runs query on the executor chosen for ctx (transaction if
// present, else the database) and scans the resulting rows into dest.
func (c CtxAwareExecutor) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	target := c.txOrNot(ctx)
	return target.SelectContext(ctx, dest, query, args...)
}
// GetContext runs query on the executor chosen for ctx (transaction if
// present, else the database) and scans a single result into dest.
func (c CtxAwareExecutor) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	target := c.txOrNot(ctx)
	return target.GetContext(ctx, dest, query, args...)
}