Tx Context

Back

2023-04-19

		
import (
	"context"
	"database/sql"
	"database/sql/driver"
	"github.com/jmoiron/sqlx"
)

// Executor is the subset of methods that *sqlx.DB and *sqlx.Tx have in common.
// Code written against Executor runs unchanged whether it is given the plain
// database handle or an open transaction. Add further shared methods here if
// callers need them.
type Executor interface {
	// GetContext has the same signature as the sqlx DB/Tx method of that name.
	GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error
	// SelectContext has the same signature as the sqlx DB/Tx method of that name.
	SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error

	// ExecContext.
	sqlx.ExecerContext
	// QueryContext, QueryxContext and QueryRowxContext.
	sqlx.QueryerContext
}

// ContextTx is a context.Context that additionally carries the executor its
// transaction runs on. NewContextTx and GetContextTx both set executor to the
// value stored in the context under CtxTxKey{} (a *sqlx.Tx in the
// NewContextTx case).
//
// NOTE(review): embedding a Context in a struct goes against the usual Go
// guidance of passing ctx explicitly as a parameter; left as-is for API
// compatibility.
type ContextTx struct {
	context.Context
	// executor receives the queries; Tx exposes it as a driver.Tx when it
	// implements that interface.
	executor Executor
}

// Tx exposes the wrapped executor as a driver.Tx, so the caller can Commit or
// Rollback it. It returns nil when the executor does not implement driver.Tx.
func (t ContextTx) Tx() driver.Tx {
	if tx, isTx := t.executor.(driver.Tx); isTx {
		return tx
	}
	return nil
}

// NewContextTx begins a transaction on db with default TxOptions and returns:
//   - a *ContextTx whose Context stores the transaction under CtxTxKey{}, so
//     GetContextTx called on it (or on any child context) finds it;
//   - the same transaction as a driver.Tx, for Commit/Rollback;
//   - the BeginTxx error, in which case both other results are nil.
//
// Nothing here commits or rolls back; the caller must end the transaction
// through the returned driver.Tx.
func NewContextTx(ctx context.Context, db *sqlx.DB) (*ContextTx, driver.Tx, error) {
	tx, err := db.BeginTxx(ctx, &sql.TxOptions{})
	if err != nil {
		return nil, nil, err
	}
	txCtx := context.WithValue(ctx, CtxTxKey{}, tx)
	wrapped := &ContextTx{Context: txCtx, executor: tx}
	return wrapped, tx, nil
}

// CtxTxKey is the context.Value key under which the active transaction
// executor is stored (written by NewContextTx and WithoutTx, read by
// GetContextTx). Because it is a distinct defined type, it cannot accidentally
// collide with context keys defined by other packages.
type CtxTxKey struct{}

// WithoutTx returns a child of ctx in which the transaction key is explicitly
// bound to nil. GetContextTx on the result therefore reports no transaction,
// even if ctx itself carried one, so CtxAwareExecutor calls made with the
// returned context go to the plain database handle instead.
func WithoutTx(ctx context.Context) context.Context {
	var key CtxTxKey
	return context.WithValue(ctx, key, nil)
}

// GetContextTx returns the transaction executor stored in ctx under
// CtxTxKey{}, wrapped in a *ContextTx. It returns nil when ctx carries no
// executor, including when the key was explicitly cleared with WithoutTx.
//
// The returned ContextTx reuses ctx as its Context. The executor is already
// reachable from ctx under CtxTxKey{}, so wrapping ctx in another
// context.WithValue layer would only cost an allocation and lengthen the
// context chain on every lookup. This function runs on every
// CtxAwareExecutor call.
func GetContextTx(ctx context.Context) *ContextTx {
	// The ok result is not needed: a failed assertion (missing key, explicit
	// nil, or a value of another type) leaves tx as a nil Executor.
	tx, _ := ctx.Value(CtxTxKey{}).(Executor)
	if tx == nil {
		return nil
	}
	return &ContextTx{
		Context:  ctx,
		executor: tx,
	}
}

type CtxAwareExecutor struct {
	db *sqlx.DB
}

func (c CtxAwareExecutor) txOrNot(ctx context.Context) Executor {
	tx := GetContextTx(ctx)
	if tx != nil {
		return tx.executor
	}
	return c.db
}

func (c CtxAwareExecutor) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	return c.txOrNot(ctx).QueryContext(ctx, query, args...)
}

func (c CtxAwareExecutor) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) {
	return c.txOrNot(ctx).QueryxContext(ctx, query, args...)
}

func (c CtxAwareExecutor) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *sqlx.Row {
	return c.txOrNot(ctx).QueryRowxContext(ctx, query, args...)
}

func (c CtxAwareExecutor) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return c.txOrNot(ctx).ExecContext(ctx, query, args...)
}

func (c CtxAwareExecutor) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	return c.txOrNot(ctx).SelectContext(ctx, dest, query, args...)
}

func (c CtxAwareExecutor) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	return c.txOrNot(ctx).GetContext(ctx, dest, query, args...)
}
            
	
}