Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
486 changes: 486 additions & 0 deletions feature/dynamodb/entitymanager/README.md

Large diffs are not rendered by default.

247 changes: 247 additions & 0 deletions feature/dynamodb/entitymanager/batch_e2e_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
package entitymanager

import (
"context"
"fmt"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
)

// getTables creates count DynamoDB tables in parallel for the batch e2e
// tests and blocks until every table is active. It fails the test on any
// configuration, schema, or table-creation error.
func getTables(t *testing.T, count int) []*Table[order] {
	t.Helper()

	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		t.Fatalf("Error loading config: %v", err)
	}
	c := dynamodb.NewFromConfig(cfg)

	tables := make([]*Table[order], count)

	wg := sync.WaitGroup{}
	errChan := make(chan error, count)

	for i := range count {
		sch, err := NewSchema[order]()
		if err != nil {
			t.Fatalf("NewSchema() error: %v", err)
		}

		// Timestamp-suffixed name so repeated/parallel runs don't collide.
		tableName := fmt.Sprintf("test_batch_e2e_%s_%d", time.Now().Format("2006_01_02_15_04_05.000000000"), i)

		sch.WithTableName(&tableName)

		tables[i], err = NewTable[order](
			c,
			WithSchema(sch),
		)
		if err != nil {
			t.Fatalf("NewTable() error: %v", err)
		}

		wg.Add(1)

		go func(table *Table[order]) {
			defer wg.Done()

			if err := table.CreateWithWait(context.Background(), time.Minute); err != nil {
				errChan <- err
			}
		}(tables[i])
	}

	wg.Wait()
	// Close only after all senders have finished so the range below
	// terminates once the buffered errors are drained. The previous
	// defer-close + range-over-open-channel pattern could block forever.
	close(errChan)

	for err := range errChan {
		t.Fatalf("CreateWithWait() error: %v", err)
	}

	return tables
}

// TestTableBatchE2E verifies a single-table round trip against real
// DynamoDB: create table, batch-write numItems orders, batch-get them back,
// and confirm the count matches.
func TestTableBatchE2E(t *testing.T) {
	t.Parallel()

	numItems := 32

	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		t.Fatalf("Error loading config: %v", err)
	}
	c := dynamodb.NewFromConfig(cfg)

	sch, err := NewSchema[order]()
	if err != nil {
		t.Fatalf("NewSchema() error: %v", err)
	}

	tableName := fmt.Sprintf("test_batch_e2e_%s", time.Now().Format("2006_01_02_15_04_05.000000000"))
	sch.WithTableName(&tableName)

	table, err := NewTable[order](
		c,
		WithSchema(sch),
	)
	if err != nil {
		t.Fatalf("NewTable() error: %v", err)
	}

	if err := table.CreateWithWait(context.Background(), time.Minute); err != nil {
		t.Fatalf("Error during CreateWithWait(): %v", err)
	}
	// Register cleanup immediately after the table exists so it is removed
	// even when a later assertion fails. The original deferred deletion at
	// the very end of the test, leaking the table on any earlier t.Fatalf.
	t.Cleanup(func() {
		if err := table.DeleteWithWait(context.Background(), time.Minute); err != nil {
			t.Errorf("Error during DeleteWithWait(): %v", err)
		}
	})

	batchWrite := table.CreateBatchWriteOperation()
	now := time.Now()
	for c := range numItems {
		batchWrite.AddPut(&order{
			OrderID:     fmt.Sprintf("order:%d", c),
			CreatedAt:   now.Unix(),
			CustomerID:  fmt.Sprintf("customer:%d", c),
			TotalAmount: 42.1337,
			customerNote: fmt.Sprintf("note:%d", c),
			address: address{
				Street: fmt.Sprintf("steet:%d", c),
				City:   fmt.Sprintf("city:%d", c),
				Zip:    fmt.Sprintf("zip:%d", c),
			},
		})
	}

	if err := batchWrite.Execute(context.Background()); err != nil {
		t.Fatalf("Error during Execute(): %v", err)
	}

	batchGet := table.CreateBatchGetOperation()
	for c := range numItems {
		batchGet.AddReadItemByMap(Map{}.With("order_id", fmt.Sprintf("order:%d", c)).With("created_at", now.Unix()))
	}

	items := make([]*order, 0, numItems)
	for res := range batchGet.Execute(context.Background()) {
		if res.Error() != nil {
			t.Errorf("Error during get: %v", res.Error())

			continue
		}

		items = append(items, res.Item())
	}

	if len(items) != numItems {
		t.Errorf("Expected to fetch %d number, got %d", numItems, len(items))
	}
}

// TestTableMultiBatchE2E verifies merged batch operations spanning several
// tables: the same numItems orders are written to every table via a merged
// write executor, read back via a merged get executor, and each order ID is
// expected to appear in every table.
func TestTableMultiBatchE2E(t *testing.T) {
	t.Parallel()

	numItems := 32
	numTables := 3 // must be at least 2: the Merge calls below index element 1

	tables := getTables(t, numTables)
	t.Cleanup(func() {
		errChan := make(chan error, len(tables))
		wg := sync.WaitGroup{}

		for _, table := range tables {
			wg.Add(1)

			go func(table *Table[order]) {
				defer wg.Done()

				if err := table.DeleteWithWait(context.Background(), time.Minute); err != nil {
					errChan <- err
				}
			}(table)
		}

		wg.Wait()
		// Close after all senders finish so the range terminates once the
		// buffered errors are drained (defer-close + range over an open
		// channel could block forever).
		close(errChan)

		for err := range errChan {
			t.Errorf("DeleteWithWait() error: %v", err)
		}
	})

	now := time.Now()

	// put items
	batchWrites := make([]*BatchWriteOperation[order], len(tables))
	for i := range tables {
		batchWrites[i] = tables[i].CreateBatchWriteOperation()
		for c := range numItems {
			batchWrites[i].AddPut(&order{
				OrderID:     fmt.Sprintf("order:%d", c),
				CreatedAt:   now.Unix(),
				CustomerID:  fmt.Sprintf("customer:%d", c),
				TotalAmount: 42.1337,
				customerNote: fmt.Sprintf("note:%d", c),
				address: address{
					Street: fmt.Sprintf("steet:%d", c),
					City:   fmt.Sprintf("city:%d", c),
					Zip:    fmt.Sprintf("zip:%d", c),
				},
			})
		}
	}

	writeExecutor := batchWrites[0].Merge(batchWrites[1])
	for c := 2; c < len(batchWrites); c++ {
		writeExecutor = writeExecutor.Merge(batchWrites[c])
	}
	if err := writeExecutor.Execute(context.Background()); err != nil {
		t.Errorf("Execute() error: %v", err)
	}

	// read items
	batchGets := make([]*BatchGetOperation[order], len(tables))
	for i := range tables {
		batchGets[i] = tables[i].CreateBatchGetOperation()
		for c := range numItems {
			batchGets[i].AddReadItemByMap(Map{}.With("order_id", fmt.Sprintf("order:%d", c)).With("created_at", now.Unix()))
		}
	}

	getExecutor := batchGets[0].Merge(batchGets[1])
	for c := 2; c < len(batchGets); c++ {
		getExecutor = getExecutor.Merge(batchGets[c])
	}

	// found maps order ID -> the table names it was read from.
	found := make(map[string][]string)
	for res := range getExecutor.Execute(context.Background()) {
		if res.Error() != nil {
			t.Errorf("Error during get: %v", res.Error())

			continue
		}

		if item, ok := res.Item().(*order); ok {
			found[item.OrderID] = append(found[item.OrderID], res.Table())
		}
	}

	for orderId, tableNames := range found {
		if len(tables) != len(tableNames) {
			// Fail the test (the original only t.Logf'd, so a missing order
			// never caused a failure).
			t.Errorf(`Order ID "%s" was not found in all tables: %v`, orderId, tableNames)
		}
	}
}
118 changes: 118 additions & 0 deletions feature/dynamodb/entitymanager/batch_get.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package entitymanager

import (
"context"
"fmt"
"iter"

"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)

// BatchGetOperation provides a batched read (BatchGetItem) operation for a DynamoDB table.
// It allows adding items to a read queue and executes the batch operation, yielding results as an iterator.
type BatchGetOperation[T any] struct {
	client Client     // DynamoDB client used to issue the BatchGetItem calls
	table  *Table[T]  // owning table; supplies read extensions and options
	schema *Schema[T] // schema used for key extraction and result decoding

	queue []batchOperation // pending get operations, consumed by the executor via queueItem
}

// AddReadItem queues a read for the given item, deriving its primary key
// from the table schema. The item must be a pointer to the struct type used
// by the table. Returns an error if key extraction fails.
func (b *BatchGetOperation[T]) AddReadItem(item *T) error {
	keyMap, err := b.schema.createKeyMap(item)
	if err != nil {
		return fmt.Errorf("error calling schema.createKeyMap: %w", err)
	}

	op := batchOperation{
		typ:  batchOperationGet,
		item: keyMap,
	}
	b.queue = append(b.queue, op)

	return nil
}

// AddReadItemByMap queues a read using an explicit key map. The map should
// hold the table's primary key attributes. It always returns nil; the error
// return keeps the signature consistent with AddReadItem.
func (b *BatchGetOperation[T]) AddReadItemByMap(m Map) error {
	op := batchOperation{typ: batchOperationGet, item: m}
	b.queue = append(b.queue, op)

	return nil
}

// Execute performs the batch get operation for all queued items.
// It yields each result (or error) as an ItemResult[*T]. If the table name
// is not set, an error is yielded. Unprocessed keys are re-queued and
// retried until all are processed or the executor's error threshold is hit.
//
// Example usage:
//
//	for res := range op.Execute(ctx) {
//		if res.Error() != nil { ... }
//	}
func (b *BatchGetOperation[T]) Execute(ctx context.Context, optFns ...func(options *dynamodb.Options)) iter.Seq[ItemResult[*T]] {
	// Delegate to a single-batcher executor; combining several operations
	// into one BatchGetItem call is done via Merge instead.
	executor := &BatchGetExecutor[*T]{
		client: b.client,
		batchers: []batcher{b},
	}
	return executor.Execute(ctx, optFns...)
}

// tableName returns the DynamoDB table name associated with this batch get
// operation.
// NOTE(review): this dereferences schema.TableName() without a nil check and
// will panic if no table name was set — confirm callers always set one.
func (b *BatchGetOperation[T]) tableName() string {
	return *b.schema.TableName()
}

// queueItem returns the queued batch operation at the given offset; the
// boolean result reports whether the offset was within the queue.
func (b *BatchGetOperation[T]) queueItem(offset int) (batchOperation, bool) {
	if offset < len(b.queue) {
		return b.queue[offset], true
	}

	return batchOperation{}, false
}

// fromMap decodes a DynamoDB attribute map into a typed item and applies the
// table's after-read extensions. It returns the decoded item, or an error
// from decoding or from an extension.
func (b *BatchGetOperation[T]) fromMap(m map[string]types.AttributeValue) (any, error) {
	i, err := b.schema.Decode(m)
	if err != nil {
		return nil, err
	}

	if err := b.table.applyAfterReadExtensions(i); err != nil {
		return nil, err
	}

	// err is necessarily nil here; return nil explicitly rather than the
	// stale err variable the original returned.
	return i, nil
}

// maxConsecutiveErrors returns the maximum number of allowed consecutive errors
// before the executor stops processing batch get results. The value comes from
// the owning table's options.
func (b *BatchGetOperation[T]) maxConsecutiveErrors() uint {
	return b.table.options.MaxConsecutiveErrors
}

// Merge combines this batch operation with additional batchers into a single
// BatchGetExecutor, allowing multiple tables or operations to run in one
// BatchGetItem call. The receiver is always the first batcher.
func (b *BatchGetOperation[T]) Merge(bs ...batcher) *BatchGetExecutor[any] {
	all := make([]batcher, 0, len(bs)+1)
	all = append(all, b)
	all = append(all, bs...)

	return &BatchGetExecutor[any]{
		client:   b.client,
		batchers: all,
	}
}

// NewBatchGetOperation creates a new BatchGetOperation for the given table.
// Use it to perform batched reads (BatchGetItem) of the table's items.
func NewBatchGetOperation[T any](table *Table[T]) *BatchGetOperation[T] {
	op := &BatchGetOperation[T]{
		schema: table.options.Schema,
		table:  table,
		client: table.client,
	}

	return op
}
Loading