diff --git a/feature/dynamodb/entitymanager/README.md b/feature/dynamodb/entitymanager/README.md new file mode 100644 index 00000000000..8b2695958d0 --- /dev/null +++ b/feature/dynamodb/entitymanager/README.md @@ -0,0 +1,486 @@ +# AWS SDK Go V2 - High Level Client + +This package provides an high level DynamoDB client for [AWS SDK Go v2](https://github.com/aws/aws-sdk-go-v2), featuring a flexible data mapper layer. It simplifies object mapping, schema management, and table operations, enabling idiomatic Go struct-to-table mapping, lifecycle hooks, and extension support for DynamoDB applications. + +## Features + +- [Automated struct to schema mapping](#automated-struct-to-schema-mapping) +- [Table management](#table-management) +- [Item operations](#item-operations) +- [Extensions](#extensions) +- [Custom converters](#custom-converters) + +## Automated Struct to Schema Mapping + +The `Schema[T]` type supports advanced table configuration options, allowing you to fine-tune your DynamoDB tables: + +- **Provisioned and On-Demand Throughput:** + - Use `WithProvisionedThroughput` or `WithOnDemandThroughput` to set capacity modes. +- **Table Class:** + - Use `WithTableClass` to set the table class (e.g., Standard, StandardInfrequentAccess). +- **Tags:** + - Use `WithTags` to add tags to your table. +- **Resource Policy, Encryption, Streams, and More:** + - Methods like `WithResourcePolicy`, `WithSSESpecification`, `WithStreamSpecification`, and `WithWarmThroughput` allow further customization. + +**Example:** + +```go +type Product struct { + ProductID string `dynamodbav:"product_id,partition"` + Category string `dynamodbav:",sort"` + Price float64 + InStock bool `dynamodbav:",omitempty"` +} + +// default usage +table, err := entitymanager.NewTable[Product](client) +if err != nil { + // handle error +} + +// customized schema options +schema, err := entitymanager.NewSchema[Product]() +if err != nil { + // handle error +} +schema = schema.WithProvisionedThroughput(&types.ProvisionedThroughput{ + ReadCapacityUnits: 5, + WriteCapacityUnits: 5, +}) +schema = schema.WithTableClass(types.TableClassStandardInfrequentAccess) +schema = schema.WithTags([]types.Tag{{Key: aws.String("env"), Value: aws.String("prod")}}) +table, err = entitymanager.NewTable[Product](client, entitymanager.WithSchema(schema)) +if err != nil { + // handle error +} +``` + +The `Table[T]` type provides a high-level, type-safe interface for managing DynamoDB tables. It abstracts away much of the boilerplate required for table lifecycle management, making it easier to work with DynamoDB in Go. + +**Key table management methods:** + +- `Create(ctx context.Context) (*dynamodb.CreateTableOutput, error)`: Creates the DynamoDB table based on the inferred or provided schema. +- `CreateWithWait(ctx context.Context, maxWaitDur time.Duration) error`: Creates the table and waits until it becomes active, or until the specified timeout is reached. +- `Describe(ctx context.Context) (*dynamodb.DescribeTableOutput, error)`: Retrieves metadata and status information about the table. +- `Delete(ctx context.Context) (*dynamodb.DeleteTableOutput, error)`: Deletes the DynamoDB table. +- `DeleteWithWait(ctx context.Context, maxWaitDur time.Duration) error`: Deletes the table and waits until it is fully removed or the timeout is reached. +- `Exists(ctx context.Context) (bool, error)`: Checks if the table exists and is accessible. + +These features allow you to manage the full lifecycle of your DynamoDB tables in a concise, idiomatic Go style, while leveraging the full power of DynamoDB's management capabilities. + +> **Note:** Table schema updates (such as adding or modifying attributes, indexes, or throughput settings after table creation) are not supported at this time. Only the table management functions listed above are available. Support for table updates is planned for a future release. + +## Item Operations + +The `Table[T]` type provides a set of strongly-typed methods for common item-level operations, making it easy to interact with DynamoDB records as native Go structs. These methods handle marshaling and unmarshaling, key construction, and error handling, so you can focus on your application logic. + +**Key item operations:** + +- `GetItem(ctx, key, ...) (*T, error)`: Retrieve a single item by its key. +- `GetItemWithProjection(ctx, key, projection, ...) (*T, error)`: Retrieve a single item by key, returning only the specified attributes. +- `PutItem(ctx, item, ...) (*T, error)`: Insert or replace an item in the table. +- `UpdateItem(ctx, item, ...) (*T, error)`: Update an existing item, using the struct as the source of changes. +- `DeleteItem(ctx, item, ...) error`: Delete an item by providing the struct value. +- `DeleteItemByKey(ctx, key, ...) error`: Delete an item by its key. +- `Scan(ctx, expr, ...) iter.Seq[ItemResult[*T]]`: Scan the table with a filter expression, returning an iterator over results. +- `ScanIndex(ctx, indexName, expr, ...) iter.Seq[ItemResult[*T]]`: Scan the index with a filter expression, returning an iterator over results. +- `Query(ctx, expr, ...) iter.Seq[ItemResult[*T]]`: Query the table or an index using a key condition expression, returning an iterator over results. +- `QueryIndex(ctx, indexName, expr, ...) iter.Seq[ItemResult[*T]]`: Query the index or an index using a key condition expression, returning an iterator over results. + +**Batch operations:** + +- `CreateBatchWriteOperation() *BatchWriteOperation[T]`: Returns a new batch write operation, allowing you to queue multiple put and delete requests and execute them efficiently in batches. Handles chunking, retries for unprocessed items, and respects DynamoDB's batch size limits. + - Use `AddPut(item *T)` or `AddRawPut(map[string]types.AttributeValue)` to queue items for writing. + - Use `AddDelete(item *T)` or `AddRawDelete(map[string]types.AttributeValue)` to queue items for deletion. + - Call `Execute(ctx, ...)` to perform the batch write. + - Use `Merge(otherBatchers...)` on a `BatchWriteOperation` to create a `BatchWriteExecutor` that can write to multiple tables in a single coordinated workflow. + +- `CreateBatchGetOperation() *BatchGetOperation[T]`: Returns a new batch get operation, allowing you to queue multiple keys for retrieval and execute them in a single batch request. Handles chunking, retries for unprocessed keys, and respects DynamoDB's batch size limits. + - Use `AddReadItem(item *T)` or `AddReadItemByMap(map[string]types.AttributeValue)` to queue keys for retrieval. + - Call `Execute(ctx, ...)` to perform the batch get, which yields results as an iterator. + - Use `Merge(otherBatchers...)` on a `BatchGetOperation` to create a `BatchGetExecutor` that can read from multiple tables in a single `BatchGetItem` workflow. + +Batch operations are useful for efficiently processing large numbers of items, minimizing network calls, and handling DynamoDB's batch constraints automatically. + +These methods are designed to be ergonomic and safe, leveraging Go's type system to reduce boilerplate and runtime errors when working with DynamoDB items. + +**Iterators and ItemResult:** + +Many methods, such as `Scan`, `Query`, and `BatchGetOperation.Execute`, return an iterator in the form of an `iter.Seq[ItemResult[*T]]`, which is a function that accepts a callback. Each callback invocation receives an `ItemResult[*T]` containing either a successfully decoded item (accessible via `Item()`) or an error encountered during retrieval or decoding (accessible via `Error()`). For batch operations that may span multiple tables, `ItemResult` also exposes a `Table()` method that returns the source table name for the item. + +When consuming these iterators, use the callback or range pattern and always check the `Error()` method on each result before using the item: + +```go +// Callback-based iteration (idiomatic for iter.Seq): +table.Scan(ctx, expr, ...)(func(result ItemResult[*T]) bool { + if err := result.Error(); err != nil { + // handle error, e.g. log or collect + return true // continue to next result + } + item := result.Item() + // process item, e.g. append to a slice or print + return true // continue, or return false to stop early +}) +// Alternative: idiomatic Go for-range over the iterator: +for res := range table.Scan(ctx, expr, ...) { + if err := res.Error(); err != nil { + // handle error + continue + } + item := res.Item() + // process item +} +``` + +This pattern ensures robust error handling and makes it easy to process large result sets efficiently and safely. + +### Advanced: merged batch operations + +You can merge batch operations from multiple tables and execute them together. This is useful when you want to minimize network calls and still keep type-safe table APIs. + +**Merged BatchGetOperation (multi-table read):** + +```go +ordersTable, _ := entitymanager.NewTable[Order](client) +customersTable, _ := entitymanager.NewTable[Customer](client) + +ordersBatch := ordersTable.CreateBatchGetOperation() +customersBatch := customersTable.CreateBatchGetOperation() + +// queue keys for both tables +for _, key := range orderKeys { + _ = ordersBatch.AddReadItemByMap(key) +} +for _, key := range customerKeys { + _ = customersBatch.AddReadItemByMap(key) +} + +// merge into a single executor +executor := ordersBatch.Merge(customersBatch) + +for res := range executor.Execute(ctx) { // iter.Seq[ItemResult[any]] + if err := res.Error(); err != nil { + // handle error + continue + } + + switch res.Table() { + case "orders": + order := res.Item().(*Order) + // process order + case "customers": + customer := res.Item().(*Customer) + // process customer + } +} +``` + +**Merged BatchWriteOperation (multi-table write):** + +```go +ordersTable, _ := entitymanager.NewTable[Order](client) +customersTable, _ := entitymanager.NewTable[Customer](client) + +ordersBatch := ordersTable.CreateBatchWriteOperation() +customersBatch := customersTable.CreateBatchWriteOperation() + +// queue writes for both tables +for _, o := range ordersToUpsert { + _ = ordersBatch.AddPut(&o) +} +for _, c := range customersToDelete { + _ = customersBatch.AddDelete(&c) +} + +// merge and execute in a single workflow +executor := ordersBatch.Merge(customersBatch) + +if err := executor.Execute(ctx); err != nil { + // handle error +} +``` + +In these advanced scenarios, the merge APIs (`Merge`) let you coordinate multi-table operations while still using the high-level, type-safe table abstractions provided by this package. Due to Go generics limitations, merged executors always return `ItemResult[any]`, so you must type-assert each item, typically by switching on `res.Table()` and then asserting the concrete type of `res.Item()`. + +## Extensions + +The entity manager supports an extension system that allows you to inject custom logic at key points in the item lifecycle. Extensions can be used for auditing, validation, automatic field population, versioning, atomic counters, and more. + +### Extension Registry and Lifecycle Hooks + +Extensions are registered using the `ExtensionRegistry`, which manages hooks for different operation phases: + +- **BeforeReader / AfterReader:** Invoked before/after reading an item (e.g., `GetItem`). +- **BeforeWriter / AfterWriter:** Invoked before/after writing an item (e.g., `PutItem`, `UpdateItem`). + +You can register multiple extensions for each phase. The registry supports method chaining for easy configuration. + +**Example: Registering extensions** + +```go +reg := &entitymanager.ExtensionRegistry[Product]{} +reg.AddBeforeReader(&MyAuditExtension{}). + AddAfterReader(&MyAuditExtension{}). + AddBeforeWriter(&MyValidationExtension{}) + +table := entitymanager.NewTable[Product](client, entitymanager.WithExtensionRegistry(reg)) +``` + +### Built-in Extensions + +The default registry includes useful extensions for common patterns: + +- **AutogenerateExtension:** Automatically populates fields such as UUIDs or timestamps. +- **AtomicCounterExtension:** Handles atomic increment/decrement fields. +- **VersionExtension:** Implements optimistic versioning for concurrency control. + +You can use the default registry or customize it as needed: + +```go +table := entitymanager.NewTable[Product]( + client, + entitymanager.WithExtensionRegistry( + entitymanager.DefaultExtensionRegistry[Product](), + ), +) +``` + +### Writing a Custom Extension + +To create your own extension, implement one or more of the extension interfaces (e.g., `BeforeWriter`, `AfterReader`). Each hook receives the context and the item, and can return an error to abort the operation (for "before" hooks). + +```go +type MyAuditExtension struct{} + +func (a *MyAuditExtension) BeforeWrite(ctx context.Context, v *Product) error { + log.Printf("Audit: about to write item: %+v", v) + return nil +} + +func (a *MyAuditExtension) AfterRead(ctx context.Context, v *Product) error { + log.Printf("Audit: read item: %+v", v) + return nil +} + + +ext := &MyAuditExtension{} +registry := DefaultExtensionRegistry[order]().Clone() +registry.AddBeforeWriter(ext) +registry.AddAfterReader(ext) + +table, err := NewTable[order]( + c, + WithSchema(sch), + WithExtensionRegistry(registry), +) +``` + +### Advanced: Expression Builders + +Extensions can also participate in building DynamoDB expressions (conditions, filters, projections, updates) by implementing the relevant builder interfaces. This allows for powerful customization of query and update logic. + +See the source and tests for more advanced extension usage patterns. + +### Built-in extensions available in the default registry + +The default extension registry includes several built-in extensions that provide common DynamoDB patterns out of the box: + + +- **AutogenerateExtension** + - Automatically populates struct fields marked for auto-generation. This includes generating UUIDs for primary keys, setting timestamps for created/updated fields, or populating other values at write time. + - **Use cases:** + - Automatically generate unique IDs for new items: + ```go + type Order struct { + ID string `dynamodbav:"id,partition,autogenerated|key"` + CreatedAt string `dynamodbav:"created_at,autogenerated|timestamp"` + } + // On PutItem, ID and CreatedAt will be set if empty. + ``` + - Set audit fields (created/updated timestamps) without manual code. + - **How it works:** + - Fields with the `autogenerated|key` or `autogenerated|timestamp` tag option are detected and set by the extension before writing. + +- **AtomicCounterExtension** + - Enables atomic increment or decrement of numeric fields marked as atomic counters. This is useful for counters, sequence numbers, or version fields that must be updated safely in concurrent environments. + - **Use cases:** + - Track the number of times an item is accessed or updated: + ```go + type PageView struct { + URL string `dynamodbav:"url,partition"` + Counter int64 `dynamodbav:"counter,atomiccounter"` + } + // On UpdateItem, Counter can be atomically incremented. + ``` + - Maintain a version or sequence number for items. + - **How it works:** + - Fields with the `atomiccounter` tag option are updated using DynamoDB's atomic update expressions, ensuring thread-safe increments/decrements. + +- **VersionExtension** + - Implements optimistic concurrency control by managing a version field on your items. This helps prevent lost updates and ensures that concurrent writes do not overwrite each other unintentionally. + - **Use cases:** + - Add a version field to your struct to enable safe concurrent updates: + ```go + type Document struct { + DocID string `dynamodbav:"doc_id,partition"` + Version int64 `dynamodbav:"version,version"` + } + // On each update, Version is checked and incremented. + ``` + - Prevent accidental overwrites in collaborative or distributed systems. + - **How it works:** + - Fields with the `version` tag option are checked and incremented on each write. If the version in the database does not match the expected value, the write fails, preventing lost updates. + +These extensions are automatically included by default via `DefaultExtensionRegistry` when you create a new table, but you can still override or customize the registry if needed: + +```go +table, err := entitymanager.NewTable[Product](client) +if err != nil { + // handle error +} + +// or customize the registry explicitly +table, err := entitymanager.NewTable[Product]( + client, + entitymanager.WithExtensionRegistry( + entitymanager.DefaultExtensionRegistry[Product](), + ), +) +if err != nil { + // handle error +} +``` + +You can also clone and customize the registry to add or remove extensions as needed for your application. + +**Note:** Because extensions can modify how your data is processed, they are enabled by default via `DefaultExtensionRegistry`. If you need different behavior, provide a custom registry with `WithExtensionRegistry` when creating the table. + +## Custom converters + +The entity manager includes a set of built-in converters for common Go types (booleans, numbers, strings, time, JSON, byte arrays, pointers, etc.), making most struct fields work out of the box. + +For advanced scenarios, you can define custom converters to handle complex or non-standard data types in your structs. By implementing the `AttributeConverter` interface, you control how a field is marshaled to and unmarshaled from DynamoDB attribute values. + +**Implementing a custom converter:** + +```go +type MyCustomType struct { + Value string +} + +type MyCustomConverter struct{} + +func (c MyCustomConverter) ToAttributeValue(v any) (types.AttributeValue, error) { + t, ok := v.(MyCustomType) + if !ok { + return nil, fmt.Errorf("expected MyCustomType") + } + return &types.AttributeValueMemberS{Value: "custom:" + t.Value}, nil +} + +func (c MyCustomConverter) FromAttributeValue(av types.AttributeValue) (any, error) { + s, ok := av.(*types.AttributeValueMemberS) + if !ok { + return nil, fmt.Errorf("expected string attribute") + } + return MyCustomType{Value: strings.TrimPrefix(s.Value, "custom:")}, nil +} +``` + +**Registering a custom converter:** + +```go +my_registry := converters.DefaultRegistry.Clone() +// or +my_registry := converters.NewRegistry() + +// register converter +my_registry.Add("my_custom_converter", &MyCustomConverter{}) + +schema, err := entitymanager.NewSchema[MyStruct](func(options *entitymanager.SchemaOptions) { + options.ConverterRegistry = my_registry +}) +if err != nil { + // handle error +} + +// add it to the struct +type MyStruct struct { + ID string `dynamodbav:"id,partition"` + CustomField MyCustomType `dynamodbav:"custom_field,converter|my_custom_converter"` +} +``` + +For more details and built-in examples, see the `converters/` directory in the source tree. + +## Example Usage + +```go +package main + +import ( + "context" + "log" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager" +) + +type Product struct { + ProductID string `dynamodbav:"product_id,partition"` + Category string `dynamodbav:",sort"` + Price float64 + InStock bool `dynamodbav:",omitempty"` +} + +func main() { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + log.Fatalf("failed to load config: %v", err) + } + client := dynamodb.NewFromConfig(cfg) + + // Create the table with schema inference + table, err := entitymanager.NewTable[Product](client) + if err != nil { + log.Fatalf("failed to create table: %v", err) + } + if err := table.CreateWithWait(context.Background(), 2*time.Minute); err != nil { + log.Fatalf("failed to create table: %v", err) + } + + // Put an item + prod := &Product{ProductID: "p1", Category: "books", Price: 19.99, InStock: true} + _, err = table.PutItem(context.Background(), prod) + if err != nil { + log.Fatalf("failed to put item: %v", err) + } + + // Get the item + key := entitymanager.Map{}.With("product_id", "p1").With("category", "books") + got, err := table.GetItem(context.Background(), key) + if err != nil { + log.Fatalf("failed to get item: %v", err) + } + log.Printf("Got item: %+v", got) + + // Scan all items (no additional filter in this example) + expr, err := expression.NewBuilder().Build() + if err != nil { + log.Fatalf("failed to build scan expression: %v", err) + } + + for res := range table.Scan(context.Background(), expr) { + if err := res.Error(); err != nil { + log.Printf("scan error: %v", err) + continue + } + item := res.Item() + log.Printf("Scanned item: %+v", item) + } +} +``` diff --git a/feature/dynamodb/entitymanager/batch_e2e_test.go b/feature/dynamodb/entitymanager/batch_e2e_test.go new file mode 100644 index 00000000000..d4b6c7be5c1 --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_e2e_test.go @@ -0,0 +1,247 @@ +package entitymanager + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +func getTables(t *testing.T, count int) []*Table[order] { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + c := dynamodb.NewFromConfig(cfg) + + tables := make([]*Table[order], count) + + wg := sync.WaitGroup{} + errChan := make(chan error, len(tables)) + defer close(errChan) + + for i := range len(tables) { + sch, err := NewSchema[order]() + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + tableName := fmt.Sprintf("test_batch_e2e_%s_%d", time.Now().Format("2006_01_02_15_04_05.000000000"), i) + + sch.WithTableName(&tableName) + + tables[i], err = NewTable[order]( + c, + WithSchema(sch), + ) + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + wg.Add(1) + + go func(table *Table[order]) { + defer wg.Done() + + if err := table.CreateWithWait(context.Background(), time.Minute); err != nil { + errChan <- err + } + }(tables[i]) + } + + wg.Wait() + + if len(errChan) > 0 { + for err := range errChan { + if err != nil { + t.Fatalf("CreateWithWait() error: %v", err) + } + } + } + + return tables +} + +func TestTableBatchE2E(t *testing.T) { + t.Parallel() + + numItems := 32 + + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + c := dynamodb.NewFromConfig(cfg) + + sch, err := NewSchema[order]() + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + tableName := fmt.Sprintf("test_batch_e2e_%s", time.Now().Format("2006_01_02_15_04_05.000000000")) + sch.WithTableName(&tableName) + + table, err := NewTable[order]( + c, + WithSchema(sch), + ) + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + if err := table.CreateWithWait(context.Background(), time.Minute); err != nil { + t.Fatalf("Error during CreateWithWait(): %v", err) + } + + batchWrite := table.CreateBatchWriteOperation() + now := time.Now() + for c := range numItems { + batchWrite.AddPut(&order{ + OrderID: fmt.Sprintf("order:%d", c), + CreatedAt: now.Unix(), + CustomerID: fmt.Sprintf("customer:%d", c), + TotalAmount: 42.1337, + customerNote: fmt.Sprintf("note:%d", c), + address: address{ + Street: fmt.Sprintf("steet:%d", c), + City: fmt.Sprintf("city:%d", c), + Zip: fmt.Sprintf("zip:%d", c), + }, + }) + } + + if err := batchWrite.Execute(context.Background()); err != nil { + t.Fatalf("Error during Execute(): %v", err) + } + + batchGet := table.CreateBatchGetOperation() + for c := range numItems { + batchGet.AddReadItemByMap(Map{}.With("order_id", fmt.Sprintf("order:%d", c)).With("created_at", now.Unix())) + } + + items := make([]*order, 0, 32) + for res := range batchGet.Execute(context.Background()) { + if res.Error() != nil { + t.Errorf("Error during get: %v", res.Error()) + + continue + } + + items = append(items, res.Item()) + } + + if len(items) != numItems { + t.Errorf("Expected to fetch %d number, got %d", numItems, len(items)) + } + + defer func() { + if err := table.DeleteWithWait(context.Background(), time.Minute); err != nil { + t.Fatalf("Error during DeleteWithWait(): %v", err) + } + }() +} + +func TestTableMultiBatchE2E(t *testing.T) { + t.Parallel() + + numItems := 32 + numTables := 3 + + tables := getTables(t, numTables) // must be higher than 2 + t.Cleanup(func() { + errChan := make(chan error, len(tables)) + defer close(errChan) + wg := sync.WaitGroup{} + + for _, table := range tables { + wg.Add(1) + + go func(table *Table[order]) { + defer wg.Done() + + if err := table.DeleteWithWait(context.Background(), time.Minute); err != nil { + errChan <- err + } + }(table) + } + + wg.Wait() + + if len(errChan) > 0 { + for err := range errChan { + if err != nil { + t.Fatalf("DeleteWithWait() error: %v", err) + } + } + } + }) + + now := time.Now() + + // put items + batchWrites := make([]*BatchWriteOperation[order], len(tables)) + for i := range tables { + batchWrites[i] = tables[i].CreateBatchWriteOperation() + for c := range numItems { + batchWrites[i].AddPut(&order{ + OrderID: fmt.Sprintf("order:%d", c), + CreatedAt: now.Unix(), + CustomerID: fmt.Sprintf("customer:%d", c), + TotalAmount: 42.1337, + customerNote: fmt.Sprintf("note:%d", c), + address: address{ + Street: fmt.Sprintf("steet:%d", c), + City: fmt.Sprintf("city:%d", c), + Zip: fmt.Sprintf("zip:%d", c), + }, + }) + } + } + + writeExecutor := batchWrites[0].Merge(batchWrites[1]) + for c := 2; c < len(batchWrites); c++ { + writeExecutor = writeExecutor.Merge(batchWrites[c]) + } + if err := writeExecutor.Execute(context.Background()); err != nil { + t.Errorf("Execute() error: %v", err) + } + + // read items + batchGets := make([]*BatchGetOperation[order], len(tables)) + for i := range tables { + batchGets[i] = tables[i].CreateBatchGetOperation() + for c := range numItems { + batchGets[i].AddReadItemByMap(Map{}.With("order_id", fmt.Sprintf("order:%d", c)).With("created_at", now.Unix())) + } + } + + getExecutor := batchGets[0].Merge(batchGets[1]) + for c := 2; c < len(batchGets); c++ { + getExecutor = getExecutor.Merge(batchGets[c]) + } + + found := make(map[string][]string) + for res := range getExecutor.Execute(context.Background()) { + if res.Error() != nil { + t.Errorf("Error during get: %v", res.Error()) + + continue + } + + if item, ok := res.Item().(*order); ok { + f := found[item.OrderID] + f = append(f, res.Table()) + found[item.OrderID] = f + } + } + + for orderId, tableNames := range found { + if len(tables) != len(tableNames) { + t.Logf(`Order ID "%s" was not found in all tables: %v`, orderId, tableNames) + } + } +} diff --git a/feature/dynamodb/entitymanager/batch_get.go b/feature/dynamodb/entitymanager/batch_get.go new file mode 100644 index 00000000000..20bf48665bb --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_get.go @@ -0,0 +1,118 @@ +package entitymanager + +import ( + "context" + "fmt" + "iter" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// BatchGetOperation provides a batched read (BatchGetItem) operation for a DynamoDB table. +// It allows adding items to a read queue and executes the batch operation, yielding results as an iterator. +type BatchGetOperation[T any] struct { + client Client + table *Table[T] + schema *Schema[T] + + queue []batchOperation +} + +// AddReadItem adds an item to the batch read queue by extracting its key using the schema. +// The item must be a pointer to the struct type used by the table. +func (b *BatchGetOperation[T]) AddReadItem(item *T) error { + m, err := b.schema.createKeyMap(item) + if err != nil { + return fmt.Errorf("error calling schema.createKeyMap: %w", err) + } + + b.queue = append(b.queue, batchOperation{ + typ: batchOperationGet, + item: m, + }) + + return nil +} + +// AddReadItemByMap adds a key map directly to the batch read queue. +// The map should represent the primary key attributes for the table. +func (b *BatchGetOperation[T]) AddReadItemByMap(m Map) error { + b.queue = append(b.queue, batchOperation{ + typ: batchOperationGet, + item: m, + }) + + return nil +} + +// Execute performs the batch get operation for all queued items. +// It yields each result (or error) as an ItemResult[*T]. If the table name +// is not set, an error is yielded. Unprocessed keys are re-queued and +// retried until all are processed or the executor's error threshold is hit. +// +// Example usage: +// +// seq := op.Execute(ctx) +// for res := range iter.Chan(seq) { ... } +func (b *BatchGetOperation[T]) Execute(ctx context.Context, optFns ...func(options *dynamodb.Options)) iter.Seq[ItemResult[*T]] { + executor := &BatchGetExecutor[*T]{ + client: b.client, + batchers: []batcher{b}, + } + return executor.Execute(ctx, optFns...) +} + +// tableName returns the DynamoDB table name associated with this batch get operation. +func (b *BatchGetOperation[T]) tableName() string { + return *b.schema.TableName() +} + +// queueItem returns the queued batch operation at the given offset, if any. +func (b *BatchGetOperation[T]) queueItem(offset int) (batchOperation, bool) { + if offset >= len(b.queue) { + return batchOperation{}, false + } + + return b.queue[offset], true +} + +// fromMap decodes a DynamoDB attribute map into a typed item and applies read extensions. +func (b *BatchGetOperation[T]) fromMap(m map[string]types.AttributeValue) (any, error) { + i, err := b.schema.Decode(m) + if err != nil { + return nil, err + } + + if err := b.table.applyAfterReadExtensions(i); err != nil { + return nil, err + } + + return i, err +} + +// maxConsecutiveErrors returns the maximum number of allowed consecutive errors +// before the executor stops processing batch get results. +func (b *BatchGetOperation[T]) maxConsecutiveErrors() uint { + return b.table.options.MaxConsecutiveErrors +} + +// Merge creates a new BatchGetExecutor that combines this batch operation with +// additional batchers, allowing multiple tables or operations to be executed +// in a single BatchGetItem call. +func (b *BatchGetOperation[T]) Merge(bs ...batcher) *BatchGetExecutor[any] { + return &BatchGetExecutor[any]{ + client: b.client, + batchers: append([]batcher{b}, bs...), + } +} + +// NewBatchGetOperation creates a new BatchGetOperation for the given table. +// Use this to perform batched reads (BatchGetItem) for the table's items. +func NewBatchGetOperation[T any](table *Table[T]) *BatchGetOperation[T] { + return &BatchGetOperation[T]{ + client: table.client, + table: table, + schema: table.options.Schema, + } +} diff --git a/feature/dynamodb/entitymanager/batch_get_executor.go b/feature/dynamodb/entitymanager/batch_get_executor.go new file mode 100644 index 00000000000..4f07dde342a --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_get_executor.go @@ -0,0 +1,142 @@ +package entitymanager + +import ( + "context" + "fmt" + "iter" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +const ( + maxItemsInBatchGet = 25 +) + +// BatchGetExecutor coordinates executing one or more batch get operations +// against DynamoDB, handling chunking, retries, and result decoding. +type BatchGetExecutor[R any] struct { + client Client + batchers []batcher + fromMappers map[string]func(m map[string]types.AttributeValue) (any, error) +} + +// fromMap finds the appropriate mapper for the given table and decodes the +// attribute map into a typed item. +func (b *BatchGetExecutor[R]) fromMap(tableName string, m map[string]types.AttributeValue) (any, error) { + if b.fromMappers == nil { + b.fromMappers = make(map[string]func(m map[string]types.AttributeValue) (any, error)) + } + + if r, ok := b.fromMappers[tableName]; ok { + return r(m) + } + + for _, br := range b.batchers { + if br.tableName() == tableName { + b.fromMappers[tableName] = br.fromMap + + break + } + } + + if r, ok := b.fromMappers[tableName]; ok { + return r(m) + } else { + return nil, fmt.Errorf(`unable to find fromMapper() for table "%s"`, tableName) + } +} + +// Merge creates a new BatchGetExecutor that combines the current batchers with +// additional ones, enabling multi-table batch get execution. +func (b *BatchGetExecutor[R]) Merge(br ...batcher) *BatchGetExecutor[any] { + return &BatchGetExecutor[any]{ + client: b.client, + batchers: append(b.batchers, br...), + } +} + +// Execute runs the batch get requests for all configured batchers, yielding +// each decoded item or error as an ItemResult[R]. It retries unprocessed keys +// until they are exhausted or the maximum consecutive error threshold is +// reached. +func (b *BatchGetExecutor[R]) Execute(ctx context.Context, optFns ...func(options *dynamodb.Options)) iter.Seq[ItemResult[R]] { + // holds the starting point for each table + batchersOffsets := map[string]int{} + + var consecutiveErrors uint = 0 + var maxConsecutiveErrors uint = 0 + + if len(b.batchers) > 0 { + maxConsecutiveErrors = b.batchers[0].maxConsecutiveErrors() + } + + if maxConsecutiveErrors == 0 { + maxConsecutiveErrors = DefaultMaxConsecutiveErrors + } + + return func(yield func(ItemResult[R]) bool) { + remainder := map[string]types.KeysAndAttributes{} + + for { + bgii := &dynamodb.BatchGetItemInput{ + RequestItems: remainder, + } + + done := 0 + for _, items := range remainder { + done += len(items.Keys) + } + + for _, br := range b.batchers { + for ; done < maxItemsInBatchGet; done++ { + offset := batchersOffsets[br.tableName()] + if item, ok := br.queueItem(offset); ok { + ri := bgii.RequestItems[br.tableName()] + ri.Keys = append(ri.Keys, item.item) + bgii.RequestItems[br.tableName()] = ri + } else { + break + } + + batchersOffsets[br.tableName()] = offset + 1 + } + } + + if done == 0 { + return + } + + res, err := b.client.BatchGetItem(ctx, bgii, optFns...) + if err != nil { + if !yield(ItemResult[R]{err: err}) { + return + } + + consecutiveErrors++ + if consecutiveErrors >= maxConsecutiveErrors { + return + } + + continue + } + + consecutiveErrors = 0 + + for tableName, items := range res.Responses { + for _, i := range items { + item, err := b.fromMap(tableName, i) + if !yield(ItemResult[R]{item: item.(R), table: tableName, err: err}) { + return + } + } + } + + if res != nil && res.UnprocessedKeys != nil { + remainder = res.UnprocessedKeys + } else { + remainder = make(map[string]types.KeysAndAttributes) + } + } + } +} diff --git a/feature/dynamodb/entitymanager/batch_get_test.go b/feature/dynamodb/entitymanager/batch_get_test.go new file mode 100644 index 00000000000..0b5bd942417 --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_get_test.go @@ -0,0 +1,138 @@ +package entitymanager + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +func TestTableBatchGetItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 9}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 8}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 7}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 6}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 0}), + // even tho we request initally all the items, we expect 2 items to be left unprocessed + // because we are forcing the UnprocessedKeys to be empty in last call + withExpectFns(expectItemsCount("order", 2)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchGetItemCall(errors.New("1"), map[string]uint{"order": 0}), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + bgio := table.CreateBatchGetOperation() + + for _, item := range c.client.(*mockClient).Items["order"] { + bgio.AddReadItemByMap(item) + } + + for res := range bgio.Execute(context.Background()) { + if c.expectedError && res.Error() == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && res.Error() != nil { + t.Fatalf("unexpected error: %v", res.Error()) + } + } + }) + } +} + +func TestTableMultiBatchGetItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withItems("order_backup", makeItem[order], 32), + // as the pool of items for order table is diminished, the request of order_backup will increase + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 9, "order_backup": 0}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 8, "order_backup": 2}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 7, "order_backup": 9}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 6, "order_backup": 15}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 0, "order_backup": 4}), + withDefaultBatchGetItemCall(nil, map[string]uint{"order": 0, "order_backup": 0}), + // even tho we request initally all the items, we expect 2 items to be left unprocessed + // because we are forcing the UnprocessedKeys to be empty in last call + withExpectFns(expectItemsCount("order", 2)), + withExpectFns(expectItemsCount("order_backup", 2)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchGetItemCall(errors.New("1"), map[string]uint{"order": 0}), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + table2, err2 := NewTable[order](c.client, func(options *TableOptions[order]) { + sch, _ := NewSchema[order]() + sch.WithTableName(aws.String("order_backup")) + options.Schema = sch + }) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + if err2 != nil { + t.Errorf("unexpcted table error: %v", err) + } + + bgio := table.CreateBatchGetOperation() + bgio2 := table2.CreateBatchGetOperation() + + for _, item := range c.client.(*mockClient).Items["order"] { + bgio.AddReadItemByMap(item) + } + for _, item := range c.client.(*mockClient).Items["order_backup"] { + bgio2.AddReadItemByMap(item) + } + + executor := bgio.Merge(bgio2) + + for res := range executor.Execute(context.Background()) { + if c.expectedError && res.Error() == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && res.Error() != nil { + t.Fatalf("unexpected error: %v", res.Error()) + } + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/batch_operation.go b/feature/dynamodb/entitymanager/batch_operation.go new file mode 100644 index 00000000000..8939942a6ad --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_operation.go @@ -0,0 +1,29 @@ +package entitymanager + +import "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + +type batchOperationType int + +const ( + batchOperationGet = iota + batchOperationPut + batchOperationDelete +) + +func (b batchOperationType) String() string { + switch b { + case batchOperationGet: + return "batchOperationGet" + case batchOperationPut: + return "batchOperationPut" + case batchOperationDelete: + return "batchOperationDelete" + default: + return "unknown" + } +} + +type batchOperation struct { + typ batchOperationType + item map[string]types.AttributeValue +} diff --git a/feature/dynamodb/entitymanager/batch_write.go b/feature/dynamodb/entitymanager/batch_write.go new file mode 100644 index 00000000000..2ccabe99229 --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_write.go @@ -0,0 +1,144 @@ +package entitymanager + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// BatchWriteOperation provides a batched write (BatchWriteItem) operation for a DynamoDB table. +// It allows adding put and delete operations to a queue and executes them in batches. +type BatchWriteOperation[T any] struct { + client Client + table *Table[T] + schema *Schema[T] + + queue []batchOperation +} + +// AddPut adds a put (insert/update) operation to the batch write queue. +// The item is encoded using the table's schema and extensions are applied before writing. +func (b *BatchWriteOperation[T]) AddPut(item *T) error { + if err := b.table.applyBeforeWriteExtensions(item); err != nil { + return fmt.Errorf("error calling table.applyBeforeWriteExtensions(): %w", err) + } + + m, err := b.schema.Encode(item) + if err != nil { + return fmt.Errorf("error calling schema.Encode(): %w", err) + } + + b.queue = append(b.queue, batchOperation{ + typ: batchOperationPut, + item: m, + }) + + return nil +} + +// AddRawPut adds a put operation to the batch write queue using a raw attribute value map. +// The map should represent the full item to be written. +func (b *BatchWriteOperation[T]) AddRawPut(i map[string]types.AttributeValue) error { + if len(i) == 0 { + return fmt.Errorf("input map is empty") + } + + b.queue = append(b.queue, batchOperation{ + typ: batchOperationPut, + item: i, + }) + + return nil +} + +// AddDelete adds a delete operation to the batch write queue for the given item. +// The item's key is extracted using the schema. +func (b *BatchWriteOperation[T]) AddDelete(item *T) error { + m, err := b.schema.createKeyMap(item) + if err != nil { + return fmt.Errorf("error calling schema.createKeyMap(): %w", err) + } + + b.queue = append(b.queue, batchOperation{ + typ: batchOperationDelete, + item: m, + }) + + return nil +} + +// AddRawDelete adds a delete operation to the batch write queue using a raw key map. +// The map should represent the primary key attributes of the item to delete. +func (b *BatchWriteOperation[T]) AddRawDelete(i map[string]types.AttributeValue) error { + if len(i) == 0 { + return fmt.Errorf("input map is empty") + } + + b.queue = append(b.queue, batchOperation{ + typ: batchOperationDelete, + item: i, + }) + + return nil +} + +// tableName returns the DynamoDB table name associated with this batch write operation. +func (b *BatchWriteOperation[T]) tableName() string { + return *b.schema.TableName() +} + +// queueItem returns the queued batch operation at the given offset, if any. +func (b *BatchWriteOperation[T]) queueItem(offset int) (batchOperation, bool) { + if offset >= len(b.queue) { + return batchOperation{}, false + } + + return b.queue[offset], true +} + +// fromMap satisfies the batcher interface for write operations. It returns nil +// because BatchWriteItem does not return items that need decoding. +func (b *BatchWriteOperation[T]) fromMap(_ map[string]types.AttributeValue) (any, error) { + return nil, nil +} + +// maxConsecutiveErrors returns the maximum number of allowed consecutive errors +// before the batch write executor stops processing requests. +func (b *BatchWriteOperation[T]) maxConsecutiveErrors() uint { + return b.table.options.MaxConsecutiveErrors +} + +// Merge creates a new BatchWriteExecutor that combines this batch write +// operation with additional batchers, allowing multiple tables or queues to be +// written in a single BatchWriteItem workflow. +func (b *BatchWriteOperation[T]) Merge(bs ...batcher) *BatchWriteExecutor[any] { + return &BatchWriteExecutor[any]{ + client: b.client, + batchers: append([]batcher{b}, bs...), + } +} + +// Execute performs the batch write operation for all queued put and delete +// requests. It sends requests in batches of up to the maximum BatchWriteItem +// size and retries unprocessed items until they are written or the +// executor's maximum consecutive error threshold is reached. +func (b *BatchWriteOperation[T]) Execute(ctx context.Context, optFns ...func(options *dynamodb.Options)) error { + executor := &BatchWriteExecutor[T]{ + client: b.client, + batchers: []batcher{b}, + } + + return executor.Execute(ctx, optFns...) +} + +// NewBatchWriteOperation creates a new BatchWriteOperation for the given table. +// Use this to perform batched put and delete operations (BatchWriteItem) for the table's items. +func NewBatchWriteOperation[T any](table *Table[T]) *BatchWriteOperation[T] { + return &BatchWriteOperation[T]{ + client: table.client, + table: table, + schema: table.options.Schema, + } +} diff --git a/feature/dynamodb/entitymanager/batch_write_executor.go b/feature/dynamodb/entitymanager/batch_write_executor.go new file mode 100644 index 00000000000..ae5c9aac0fa --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_write_executor.go @@ -0,0 +1,116 @@ +package entitymanager + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +const ( + maxItemsInBatchWrite = 25 // maximum items allowed in a single BatchWriteItem call +) + +// BatchWriteExecutor coordinates executing one or more batch write operations +// against DynamoDB, handling chunking, retries, and error thresholds. +type BatchWriteExecutor[R any] struct { + client Client + batchers []batcher + fromMappers map[string]func(m map[string]types.AttributeValue) (any, error) +} + +// Merge creates a new BatchWriteExecutor that combines the current batchers +// with additional ones, allowing multi-table batch write execution. +func (b *BatchWriteExecutor[R]) Merge(br ...batcher) *BatchWriteExecutor[any] { + return &BatchWriteExecutor[any]{ + client: b.client, + batchers: append(b.batchers, br...), + } +} + +// Execute runs the batch write requests for all configured batchers. It sends +// requests in batches of up to maxItemsInBatchWrite items, and retries +// unprocessed items until they are written or the maximum consecutive error +// threshold is reached. Returns the last error encountered when the threshold +// is exceeded, or nil on success. +func (b *BatchWriteExecutor[R]) Execute(ctx context.Context, optFns ...func(options *dynamodb.Options)) error { + // holds the starting point for each table + batchersOffsets := map[string]int{} + + var consecutiveErrors uint = 0 + var maxConsecutiveErrors uint = 0 + + if len(b.batchers) > 0 { + maxConsecutiveErrors = b.batchers[0].maxConsecutiveErrors() + } + + if maxConsecutiveErrors == 0 { + maxConsecutiveErrors = DefaultMaxConsecutiveErrors + } + + remainder := make(map[string][]types.WriteRequest) + + for { + bwii := &dynamodb.BatchWriteItemInput{ + RequestItems: remainder, + } + done := 0 + + for _, items := range remainder { + done += len(items) + } + + for _, br := range b.batchers { + for ; done < maxItemsInBatchWrite; done++ { + offset := batchersOffsets[br.tableName()] + if item, ok := br.queueItem(offset); ok { + ri := bwii.RequestItems[br.tableName()] + switch item.typ { + case batchOperationPut: + ri = append(ri, types.WriteRequest{ + PutRequest: &types.PutRequest{ + Item: item.item, + }, + }) + case batchOperationDelete: + ri = append(ri, types.WriteRequest{ + DeleteRequest: &types.DeleteRequest{ + Key: item.item, + }, + }) + default: + return fmt.Errorf(`unsupported operation type found: "%s"`, item.typ) + } + bwii.RequestItems[br.tableName()] = ri + } else { + break + } + + batchersOffsets[br.tableName()] = offset + 1 + } + } + + if done == 0 { + break + } + + res, err := b.client.BatchWriteItem(ctx, bwii, optFns...) + if err != nil { + consecutiveErrors++ + if consecutiveErrors >= maxConsecutiveErrors { + return err + } + } + + consecutiveErrors = 0 + + if res != nil && res.UnprocessedItems != nil { + remainder = res.UnprocessedItems + } else { + remainder = make(map[string][]types.WriteRequest) + } + } + + return nil +} diff --git a/feature/dynamodb/entitymanager/batch_write_test.go b/feature/dynamodb/entitymanager/batch_write_test.go new file mode 100644 index 00000000000..02686ac5164 --- /dev/null +++ b/feature/dynamodb/entitymanager/batch_write_test.go @@ -0,0 +1,167 @@ +package entitymanager + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +func TestTableBatchWriteItem(t *testing.T) { + cases := []struct { + client Client + isDelete bool + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 9}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 8}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 7}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 6}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0}), + withExpectFns(expectItemsCount("order", 62)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 9}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 8}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 7}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 6}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0}), + withExpectFns(expectItemsCount("order", 2)), + ), + isDelete: true, + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchWriteItemCall(errors.New("1"), map[string]uint{"order": 0}), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + bgwo := table.CreateBatchWriteOperation() + + for range 32 { + if c.isDelete { + bgwo.AddRawDelete(makeItem[order]()) + } else { + bgwo.AddRawPut(makeItem[order]()) + } + } + + err = bgwo.Execute(context.Background()) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableMultiBatchWriteItem(t *testing.T) { + cases := []struct { + client Client + isDelete bool + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withItems("order_backup", makeItem[order], 32), + // as the pool of items for order table is diminished, the request of order_backup will increase + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 9, "order_backup": 0}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 8, "order_backup": 2}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 7, "order_backup": 9}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 6, "order_backup": 15}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0, "order_backup": 4}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0, "order_backup": 0}), + withExpectFns(expectItemsCount("order", 62)), + withExpectFns(expectItemsCount("order_backup", 62)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withItems("order_backup", makeItem[order], 32), + // as the pool of items for order table is diminished, the request of order_backup will increase + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 9, "order_backup": 0}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 8, "order_backup": 2}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 7, "order_backup": 9}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 6, "order_backup": 15}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0, "order_backup": 4}), + withDefaultBatchWriteItemCall(nil, map[string]uint{"order": 0, "order_backup": 0}), + withExpectFns(expectItemsCount("order", 2)), + withExpectFns(expectItemsCount("order_backup", 2)), + ), + isDelete: true, + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultBatchWriteItemCall(errors.New("1"), map[string]uint{"order": 0}), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + table2, err2 := NewTable[order](c.client, func(options *TableOptions[order]) { + sch, _ := NewSchema[order]() + sch.WithTableName(aws.String("order_backup")) + options.Schema = sch + }) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + if err2 != nil { + t.Errorf("unexpcted table error: %v", err) + } + + bgwo := table.CreateBatchWriteOperation() + bgwo2 := table2.CreateBatchWriteOperation() + + for range 32 { + if c.isDelete { + bgwo.AddRawDelete(makeItem[order]()) + bgwo2.AddRawDelete(makeItem[order]()) + } else { + bgwo.AddRawPut(makeItem[order]()) + bgwo2.AddRawPut(makeItem[order]()) + } + } + + err = bgwo.Merge(bgwo2).Execute(context.Background()) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/batcher.go b/feature/dynamodb/entitymanager/batcher.go new file mode 100644 index 00000000000..0f5a02bd481 --- /dev/null +++ b/feature/dynamodb/entitymanager/batcher.go @@ -0,0 +1,20 @@ +package entitymanager + +import "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" + +// batcher is an internal interface implemented by batch get and batch write +// operations so their executors can treat them uniformly. It exposes access to +// queued operations, the target table name, error thresholds, and optional +// mapping from raw attribute maps to typed items. +type batcher interface { + // queueItem returns the queued batch operation at the given offset, if any. + queueItem(int) (batchOperation, bool) + // tableName returns the DynamoDB table name associated with this batcher. + tableName() string + // maxConsecutiveErrors returns the maximum number of consecutive errors + // allowed before the executor stops processing. + maxConsecutiveErrors() uint + // fromMap converts a DynamoDB attribute map into a typed item, when + // applicable (read operations). For write operations it may be a no-op. + fromMap(m map[string]types.AttributeValue) (any, error) +} diff --git a/feature/dynamodb/entitymanager/client.go b/feature/dynamodb/entitymanager/client.go new file mode 100644 index 00000000000..74eb078c532 --- /dev/null +++ b/feature/dynamodb/entitymanager/client.go @@ -0,0 +1,27 @@ +package entitymanager + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +// Client defines the minimal DynamoDB client interface required by the entitymanager package. +// It abstracts the AWS SDK DynamoDB client to enable easier testing and extension. +// Any implementation (including mocks) must satisfy this interface to be used with Table and batch operations. +type Client interface { + CreateTable(ctx context.Context, input *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) + DescribeTable(ctx context.Context, input *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) + DeleteTable(ctx context.Context, input *dynamodb.DeleteTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteTableOutput, error) + + GetItem(ctx context.Context, input *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) + PutItem(ctx context.Context, input *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) + DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) + UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) + + BatchGetItem(ctx context.Context, input *dynamodb.BatchGetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchGetItemOutput, error) + BatchWriteItem(ctx context.Context, input *dynamodb.BatchWriteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchWriteItemOutput, error) + + Scan(ctx context.Context, input *dynamodb.ScanInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) + Query(ctx context.Context, input *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) +} diff --git a/feature/dynamodb/entitymanager/converters/bool.go b/feature/dynamodb/entitymanager/converters/bool.go new file mode 100644 index 00000000000..00e979d258a --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/bool.go @@ -0,0 +1,44 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[bool] = (*BoolConverter)(nil) + +// BoolConverter converts between Go bool values and DynamoDB BOOL AttributeValues. +// +// BoolConverter only supports the DynamoDB BOOL attribute type: +// +// - FromAttributeValue accepts *types.AttributeValueMemberBOOL and returns the contained bool. +// If the provided AttributeValue is nil (the pointer to the member type is nil), +// it returns ErrNilValue. +// If the provided AttributeValue is any other AttributeValue type, it returns unsupportedType. +// +// - ToAttributeValue converts a Go bool into *types.AttributeValueMemberBOOL. +// +// This converter does NOT interpret strings or numbers as booleans — any non-BOOL +// AttributeValue will cause an unsupportedType error. +type BoolConverter struct{} + +// FromAttributeValue converts a DynamoDB BOOL AttributeValue to a Go bool. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (n BoolConverter) FromAttributeValue(v types.AttributeValue, _ []string) (bool, error) { + switch av := v.(type) { + case *types.AttributeValueMemberBOOL: + if av == nil { + return false, ErrNilValue + } + + return av.Value, nil + default: + return false, unsupportedType(v, types.AttributeValueMemberBOOL{}) + } +} + +// ToAttributeValue converts a Go bool to a DynamoDB BOOL AttributeValue. +func (n BoolConverter) ToAttributeValue(v bool, _ []string) (types.AttributeValue, error) { + return &types.AttributeValueMemberBOOL{ + Value: v, + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/bool_ptr.go b/feature/dynamodb/entitymanager/converters/bool_ptr.go new file mode 100644 index 00000000000..600e82f1759 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/bool_ptr.go @@ -0,0 +1,54 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[*bool] = (*BoolPtrConverter)(nil) + +// BoolPtrConverter converts between Go *bool values and DynamoDB BOOL AttributeValues. +// +// BoolPtrConverter behaves similarly to BoolConverter but operates on pointer +// values instead of plain bools. This allows encoding and decoding of nullable +// boolean fields. +// +// Supported conversions: +// +// - FromAttributeValue accepts *types.AttributeValueMemberBOOL and returns +// a Go *bool pointing to the contained value. +// If the provided AttributeValue is nil, it returns ErrNilValue. +// If the provided AttributeValue is any other type, it returns unsupportedType. +// +// - ToAttributeValue converts a Go *bool into *types.AttributeValueMemberBOOL. +// If the input pointer is nil, it returns ErrNilValue. +// +// This converter only handles the DynamoDB BOOL type. All other AttributeValue +// variants will result in an unsupportedType error. +type BoolPtrConverter struct{} + +// FromAttributeValue converts a DynamoDB BOOL AttributeValue to a Go *bool. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (n BoolPtrConverter) FromAttributeValue(v types.AttributeValue, _ []string) (*bool, error) { + switch av := v.(type) { + case *types.AttributeValueMemberBOOL: + if av == nil { + return nil, ErrNilValue + } + + return &av.Value, nil + default: + return nil, unsupportedType(v, (*types.AttributeValueMemberBOOL)(nil)) + } +} + +// ToAttributeValue converts a Go *bool to a DynamoDB BOOL AttributeValue. +// Returns ErrNilValue if the input pointer is nil. +func (n BoolPtrConverter) ToAttributeValue(v *bool, _ []string) (types.AttributeValue, error) { + if v == nil { + return nil, ErrNilValue + } + + return &types.AttributeValueMemberBOOL{ + Value: *v, + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/bool_ptr_test.go b/feature/dynamodb/entitymanager/converters/bool_ptr_test.go new file mode 100644 index 00000000000..ffac9122b89 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/bool_ptr_test.go @@ -0,0 +1,90 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestBoolPtrConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + opts []string + expectedOutput any + expectedError bool + }{ + {input: &types.AttributeValueMemberBOOL{Value: true}, opts: []string{}, expectedOutput: aws.Bool(true), expectedError: false}, + {input: &types.AttributeValueMemberBOOL{Value: false}, opts: []string{}, expectedOutput: aws.Bool(false), expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (types.AttributeValue)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberN)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberS)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberBOOL)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: &types.AttributeValueMemberS{Value: "true"}, opts: nil, expectedOutput: nil, expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := BoolPtrConverter{} + + actualOutput, actualError := tc.FromAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestBoolPtrConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input *bool + opts []string + expectedOutput types.AttributeValue + expectedError bool + }{ + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: aws.Bool(true), opts: nil, expectedOutput: &types.AttributeValueMemberBOOL{Value: true}, expectedError: false}, + {input: aws.Bool(false), opts: nil, expectedOutput: &types.AttributeValueMemberBOOL{Value: false}, expectedError: false}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := BoolPtrConverter{} + + actualOutput, actualError := tc.ToAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/bool_test.go b/feature/dynamodb/entitymanager/converters/bool_test.go new file mode 100644 index 00000000000..f4755fd6098 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/bool_test.go @@ -0,0 +1,88 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestBoolConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + opts []string + expectedOutput any + expectedError bool + }{ + {input: &types.AttributeValueMemberBOOL{Value: true}, opts: []string{}, expectedOutput: true, expectedError: false}, + {input: &types.AttributeValueMemberBOOL{Value: false}, opts: []string{}, expectedOutput: false, expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (types.AttributeValue)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberN)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberS)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberBOOL)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: &types.AttributeValueMemberS{Value: "true"}, opts: nil, expectedOutput: nil, expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := BoolConverter{} + + actualOutput, actualError := tc.FromAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestBoolConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input bool + opts []string + expectedOutput types.AttributeValue + expectedError bool + }{ + {input: true, opts: nil, expectedOutput: &types.AttributeValueMemberBOOL{Value: true}, expectedError: false}, + {input: false, opts: nil, expectedOutput: &types.AttributeValueMemberBOOL{Value: false}, expectedError: false}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := BoolConverter{} + + actualOutput, actualError := tc.ToAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/byte_array.go b/feature/dynamodb/entitymanager/converters/byte_array.go new file mode 100644 index 00000000000..e2698d011ee --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/byte_array.go @@ -0,0 +1,59 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// ByteArrayConverter converts between Go byte slices ([]byte) and DynamoDB +// binary (B) AttributeValues. +// +// Behaviour: +// +// FromAttributeValue: +// - Accepts *types.AttributeValueMemberB and returns its Value ([]byte). +// - Returns ErrNilValue if the AttributeValue pointer is nil. +// - Returns unsupportedType if the AttributeValue is not of type B. +// +// ToAttributeValue: +// - Converts a Go []byte into a *types.AttributeValueMemberB containing that byte slice. +// - Returns ErrNilValue if the input slice is nil. +// +// This converter only supports DynamoDB binary attributes and Go []byte values. +// Any other AttributeValue type will cause an unsupportedType error. +// +// Example: +// +// var c converters.ByteArrayConverter +// av, _ := c.ToAttributeValue([]byte("hello"), nil) +// // av == &types.AttributeValueMemberB{Value: []byte("hello")} +// +// v, _ := c.FromAttributeValue(av, nil) +// // v == []byte("hello") +type ByteArrayConverter struct { +} + +// FromAttributeValue converts a DynamoDB binary (B) AttributeValue to a Go []byte. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (n ByteArrayConverter) FromAttributeValue(v types.AttributeValue, _ []string) ([]byte, error) { + switch av := v.(type) { + case *types.AttributeValueMemberB: + if av == nil { + return nil, ErrNilValue + } + return av.Value, nil + default: + return nil, unsupportedType(v, (*types.AttributeValueMemberB)(nil)) + } +} + +// ToAttributeValue converts a Go []byte to a DynamoDB binary (B) AttributeValue. +// Returns ErrNilValue if the input slice is nil. +func (n ByteArrayConverter) ToAttributeValue(v []byte, _ []string) (types.AttributeValue, error) { + if v == nil { + return nil, ErrNilValue + } + + return &types.AttributeValueMemberB{ + Value: v, + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/byte_array_test.go b/feature/dynamodb/entitymanager/converters/byte_array_test.go new file mode 100644 index 00000000000..fa77b5eb57b --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/byte_array_test.go @@ -0,0 +1,89 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestByteArrayConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + opts []string + expectedOutput any + expectedError bool + }{ + {input: &types.AttributeValueMemberB{Value: nil}, opts: []string{}, expectedOutput: ([]byte)(nil), expectedError: false}, + {input: &types.AttributeValueMemberB{Value: []byte("test")}, opts: []string{}, expectedOutput: []byte("test"), expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (types.AttributeValue)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberN)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberS)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberBOOL)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: &types.AttributeValueMemberS{Value: "true"}, opts: nil, expectedOutput: nil, expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := ByteArrayConverter{} + + actualOutput, actualError := tc.FromAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestByteArrayConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input []uint8 + opts []string + expectedOutput types.AttributeValue + expectedError bool + }{ + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: ([]byte)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: []byte("tests"), opts: nil, expectedOutput: &types.AttributeValueMemberB{Value: []byte("tests")}, expectedError: false}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := ByteArrayConverter{} + + actualOutput, actualError := tc.ToAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/errors.go b/feature/dynamodb/entitymanager/converters/errors.go new file mode 100644 index 00000000000..95deaaafa37 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/errors.go @@ -0,0 +1,28 @@ +package converters + +import ( + "errors" + "fmt" + "strings" +) + +// ErrNilValue is returned when a nil value is encountered where a non-nil value is required. +var ErrNilValue = errors.New("nil value error") + +// unsupportedType returns a formatted error indicating the provided type is not supported. +// Optionally lists the supported types for better diagnostics. +func unsupportedType(unsupported any, supported ...any) error { + err := fmt.Errorf("unsupported type: %T", unsupported) + + if len(supported) > 0 { + var sup []string + + for i := range supported { + sup = append(sup, fmt.Sprintf("%T", supported[i])) + } + + err = fmt.Errorf("expected %s, got %s", strings.Join(sup, " or "), err.Error()) + } + + return err +} diff --git a/feature/dynamodb/entitymanager/converters/json.go b/feature/dynamodb/entitymanager/converters/json.go new file mode 100644 index 00000000000..0430d18fca8 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/json.go @@ -0,0 +1,90 @@ +package converters + +import ( + "encoding/json" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[any] = (*JSONConverter)(nil) + +// JSONConverter converts between arbitrary Go values and DynamoDB AttributeValues +// that contain JSON data. +// +// Behaviour (matches the implementation in json.go): +// +// FromAttributeValue: +// - Accepts *types.AttributeValueMemberB: treats av.Value as raw JSON bytes +// and unmarshals into an `any` value. If av is nil, returns ErrNilValue. +// - Accepts *types.AttributeValueMemberS: treats av.Value as a JSON string +// and unmarshals into an `any` value. If av is nil, returns ErrNilValue. +// - Accepts *types.AttributeValueMemberNULL: returns (nil, nil). +// - Any other AttributeValue concrete type returns unsupportedType(...). +// +// ToAttributeValue: +// - If the provided Go value is nil, returns *types.AttributeValueMemberNULL{Value: true}. +// - Marshals the Go value to JSON. If `opts` contains `as=bytes`, returns +// *types.AttributeValueMemberB with the JSON bytes; otherwise returns +// *types.AttributeValueMemberS with the JSON string. +// - Returns any JSON marshal error encountered. +// +// Notes: +// - This converter only recognizes DynamoDB B and S attribute members containing +// JSON payloads (and NULL). It does not convert arbitrary AttributeValue types. +// - The converter returns ErrNilValue when a concrete B or S member pointer is nil, +// and unsupportedType for mismatched AttributeValue types. +type JSONConverter struct { +} + +// FromAttributeValue converts a DynamoDB AttributeValue containing JSON data (as B or S) to a Go value. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (j JSONConverter) FromAttributeValue(v types.AttributeValue, _ []string) (any, error) { + switch av := v.(type) { + case *types.AttributeValueMemberB: + if av == nil { + return nil, ErrNilValue + } + + var o any + err := json.Unmarshal(av.Value, &o) + return o, err + case *types.AttributeValueMemberS: + if av == nil { + return nil, ErrNilValue + } + + var o any + err := json.Unmarshal([]byte(av.Value), &o) + return o, err + case *types.AttributeValueMemberNULL: + return nil, nil + default: + return "", unsupportedType(v, types.AttributeValueMemberS{}, types.AttributeValueMemberB{}, &types.AttributeValueMemberNULL{}) + } +} + +// ToAttributeValue marshals a Go value to a DynamoDB AttributeValue as JSON (S or B), or NULL if v is nil. +// The "as=bytes" option stores the JSON as binary (B); otherwise, it stores as string (S). +func (j JSONConverter) ToAttributeValue(v any, opts []string) (types.AttributeValue, error) { + as := getOpt(opts, "as") + + if v == nil { + return &types.AttributeValueMemberNULL{Value: true}, nil + } + + val, err := json.Marshal(v) + + if err != nil { + return nil, err + } + + if as == "bytes" { + return &types.AttributeValueMemberB{ + Value: val, + }, nil + } + + return &types.AttributeValueMemberS{ + Value: string(val), + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/json_test.go b/feature/dynamodb/entitymanager/converters/json_test.go new file mode 100644 index 00000000000..dc22fb67716 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/json_test.go @@ -0,0 +1,160 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestJsonConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + expectedOutput any + expectedError bool + }{ + { + input: &types.AttributeValueMemberNULL{Value: true}, + expectedOutput: nil, + }, + { + input: &types.AttributeValueMemberNULL{Value: false}, + expectedOutput: nil, + }, + { + input: &types.AttributeValueMemberS{Value: "[]"}, + expectedOutput: []any{}, + }, + { + input: &types.AttributeValueMemberS{Value: "{}"}, + expectedOutput: map[string]any{}, + }, + { + input: &types.AttributeValueMemberS{Value: `{"test":"test"}`}, + expectedOutput: map[string]any{ + "test": "test", + }, + }, + { + input: &types.AttributeValueMemberS{Value: `[{"test":"test"}]`}, + expectedOutput: []any{ + map[string]any{ + "test": "test", + }, + }, + }, + { + input: &types.AttributeValueMemberS{Value: `"test"`}, + expectedOutput: "test", + }, + { + input: &types.AttributeValueMemberS{Value: `[`}, + expectedError: true, + }, + { + input: &types.AttributeValueMemberS{Value: `[{"test":"test}"`}, + expectedError: true, + }, + { + input: &types.AttributeValueMemberS{Value: ""}, + expectedError: true, + }, + { + input: &types.AttributeValueMemberB{Value: []byte{}}, + expectedError: true, + }, + { + input: &types.AttributeValueMemberM{}, + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + j := JSONConverter{} + actualOutput, actualError := j.FromAttributeValue(c.input, nil) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestJsonConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input any + expectedOutput types.AttributeValue + expectedError bool + }{ + { + expectedOutput: &types.AttributeValueMemberNULL{Value: true}, + }, + { + input: "test", + expectedOutput: &types.AttributeValueMemberS{Value: `"test"`}, + }, + { + input: []any{}, + expectedOutput: &types.AttributeValueMemberS{Value: `[]`}, + }, + { + input: map[string]any{}, + expectedOutput: &types.AttributeValueMemberS{Value: `{}`}, + }, + { + input: map[string]any{"test": "test"}, + expectedOutput: &types.AttributeValueMemberS{Value: `{"test":"test"}`}, + }, + { + input: []string{"test"}, + expectedOutput: &types.AttributeValueMemberS{Value: `["test"]`}, + }, + { + input: struct { + Test string `json:"test"` + hidden string `json:"hidden"` + }{ + Test: "test", + hidden: "you-can't-see-me", + }, + expectedOutput: &types.AttributeValueMemberS{Value: `{"test":"test"}`}, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + j := JSONConverter{} + actualOutput, actualError := j.ToAttributeValue(c.input, nil) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/numeric.go b/feature/dynamodb/entitymanager/converters/numeric.go new file mode 100644 index 00000000000..50f4774a794 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/numeric.go @@ -0,0 +1,91 @@ +package converters + +import ( + "fmt" + "strconv" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type number interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~float32 | ~float64 +} + +var _ AttributeConverter[uint] = (*NumericConverter[uint])(nil) +var _ AttributeConverter[uint8] = (*NumericConverter[uint8])(nil) +var _ AttributeConverter[uint16] = (*NumericConverter[uint16])(nil) +var _ AttributeConverter[uint32] = (*NumericConverter[uint32])(nil) +var _ AttributeConverter[uint64] = (*NumericConverter[uint64])(nil) +var _ AttributeConverter[int] = (*NumericConverter[int])(nil) +var _ AttributeConverter[int8] = (*NumericConverter[int8])(nil) +var _ AttributeConverter[int16] = (*NumericConverter[int16])(nil) +var _ AttributeConverter[int32] = (*NumericConverter[int32])(nil) +var _ AttributeConverter[int64] = (*NumericConverter[int64])(nil) +var _ AttributeConverter[float32] = (*NumericConverter[float32])(nil) +var _ AttributeConverter[float64] = (*NumericConverter[float64])(nil) + +// NumericConverter converts between Go numeric types and DynamoDB number (N) +// AttributeValues. +// +// It is a generic converter parameterized by T, which must satisfy the internal +// `number` constraint (e.g. uint, int64, float64, etc.). +// +// Behaviour: +// +// FromAttributeValue: +// - Accepts *types.AttributeValueMemberN and parses av.Value into the target +// numeric type T using strconv. +// - Returns ErrNilValue if the AttributeValue pointer is nil. +// - Returns unsupportedType if the AttributeValue is not a number type. +// +// ToAttributeValue: +// - Converts a Go numeric value of type T into a *types.AttributeValueMemberN +// with its string representation (via strconv.FormatFloat / strconv.FormatInt). +// +// This converter only supports DynamoDB number attributes and Go numeric types; +// any other AttributeValue type or Go kind will result in an error. +// +// Example: +// +// var c converters.NumericConverter[int] +// av, _ := c.ToAttributeValue(42, nil) +// // av == &types.AttributeValueMemberN{Value: "42"} +// +// v, _ := c.FromAttributeValue(av, nil) +// // v == int(42) +type NumericConverter[T number] struct { +} + +func (n NumericConverter[T]) FromAttributeValue(v types.AttributeValue, i []string) (T, error) { + out := *new(T) + + switch av := v.(type) { + case *types.AttributeValueMemberN: + if strings.Contains(av.Value, ".") { + f, err := strconv.ParseFloat(av.Value, 64) + if err != nil { + return T(0), err + } + + out = T(f) + } else { + i, err := strconv.ParseInt(av.Value, 10, 64) + if err != nil { + return T(0), err + } + + out = T(i) + } + default: + return T(0), unsupportedType(v, (*types.AttributeValueMemberN)(nil)) + } + + return out, nil +} + +func (n NumericConverter[T]) ToAttributeValue(t T, i []string) (types.AttributeValue, error) { + return &types.AttributeValueMemberN{ + Value: fmt.Sprintf("%v", t), + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/numeric_ptr.go b/feature/dynamodb/entitymanager/converters/numeric_ptr.go new file mode 100644 index 00000000000..12c048be860 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/numeric_ptr.go @@ -0,0 +1,94 @@ +package converters + +import ( + "fmt" + "strconv" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[*uint] = (*NumericPtrConverter[uint])(nil) +var _ AttributeConverter[*uint8] = (*NumericPtrConverter[uint8])(nil) +var _ AttributeConverter[*uint16] = (*NumericPtrConverter[uint16])(nil) +var _ AttributeConverter[*uint32] = (*NumericPtrConverter[uint32])(nil) +var _ AttributeConverter[*uint64] = (*NumericPtrConverter[uint64])(nil) +var _ AttributeConverter[*int] = (*NumericPtrConverter[int])(nil) +var _ AttributeConverter[*int8] = (*NumericPtrConverter[int8])(nil) +var _ AttributeConverter[*int16] = (*NumericPtrConverter[int16])(nil) +var _ AttributeConverter[*int32] = (*NumericPtrConverter[int32])(nil) +var _ AttributeConverter[*int64] = (*NumericPtrConverter[int64])(nil) +var _ AttributeConverter[*float32] = (*NumericPtrConverter[float32])(nil) +var _ AttributeConverter[*float64] = (*NumericPtrConverter[float64])(nil) + +// NumericPtrConverter converts between Go pointer-to-number values (*T) +// and DynamoDB number (N) AttributeValues. +// +// It is a generic converter parameterized by T, which must satisfy the internal +// `number` constraint (e.g. uint, int64, float64, etc.). +// +// Behaviour: +// +// FromAttributeValue: +// - Accepts *types.AttributeValueMemberN and parses av.Value into a new value +// of type T, returning a pointer to it (*T). +// - Returns ErrNilValue if the AttributeValue pointer is nil. +// - Returns unsupportedType if the AttributeValue is not a number type. +// +// ToAttributeValue: +// - Converts a Go *T into a *types.AttributeValueMemberN containing the string +// representation of the pointed-to numeric value. +// - Returns ErrNilValue if the input pointer is nil. +// +// This converter only supports DynamoDB number attributes and Go numeric pointer +// types. Any other AttributeValue type or nil handling violation will result in +// an error. +// +// Example: +// +// var c converters.NumericPtrConverter[int] +// v := 42 +// av, _ := c.ToAttributeValue(&v, nil) +// // av == &types.AttributeValueMemberN{Value: "42"} +// +// out, _ := c.FromAttributeValue(av, nil) +// // *out == 42 +type NumericPtrConverter[T number] struct { +} + +func (n NumericPtrConverter[T]) FromAttributeValue(v types.AttributeValue, _ []string) (*T, error) { + out := new(T) + + switch av := v.(type) { + case *types.AttributeValueMemberN: + if strings.Contains(av.Value, ".") { + f, err := strconv.ParseFloat(av.Value, 64) + if err != nil { + return nil, err + } + + *out = T(f) + } else { + i, err := strconv.ParseInt(av.Value, 10, 64) + if err != nil { + return nil, err + } + + *out = T(i) + } + default: + return nil, unsupportedType(v, (*types.AttributeValueMemberN)(nil)) + } + + return out, nil +} + +func (n NumericPtrConverter[T]) ToAttributeValue(v *T, _ []string) (types.AttributeValue, error) { + if v == nil { + return nil, ErrNilValue + } + + return &types.AttributeValueMemberN{ + Value: fmt.Sprintf("%v", *v), + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/numeric_ptr_test.go b/feature/dynamodb/entitymanager/converters/numeric_ptr_test.go new file mode 100644 index 00000000000..85c580be4b7 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/numeric_ptr_test.go @@ -0,0 +1,348 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestNumericPtrConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + options []string + expectedOutput any + expectedError bool + }{ + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: aws.Uint(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123"}, + expectedOutput: aws.Uint(18446744073709551493), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Uint(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: aws.Uint8(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "1234"}, + expectedOutput: aws.Uint8(210), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-1234"}, + expectedOutput: aws.Uint8(46), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Uint8(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: aws.Uint16(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456"}, + expectedOutput: aws.Uint16(57920), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Uint16(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: aws.Uint32(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "12345678901"}, + expectedOutput: aws.Uint32(3755744309), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Uint32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: aws.Uint64(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: aws.Uint64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Uint64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123.10"}, + expectedOutput: aws.Float32(123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123.10"}, + expectedOutput: aws.Float32(-123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: aws.Float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123,10"}, + expectedOutput: aws.Float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123.10"}, + expectedOutput: aws.Float64(123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123.10"}, + expectedOutput: aws.Float64(-123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: aws.Float64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123,10"}, + expectedOutput: aws.Float64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: aws.Float64(0), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var actualError error + var actualOutput any + + switch v := c.expectedOutput.(type) { + case *uint: + cvt := NumericPtrConverter[uint]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *uint8: + cvt := NumericPtrConverter[uint8]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *uint16: + cvt := NumericPtrConverter[uint16]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *uint32: + cvt := NumericPtrConverter[uint32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *uint64: + cvt := NumericPtrConverter[uint64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *int: + cvt := NumericPtrConverter[int]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *int8: + cvt := NumericPtrConverter[int8]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *int16: + cvt := NumericPtrConverter[int16]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *int32: + cvt := NumericPtrConverter[int32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *int64: + cvt := NumericPtrConverter[int64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *float32: + cvt := NumericPtrConverter[float32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case *float64: + cvt := NumericPtrConverter[float64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + default: + t.Errorf("unsupported type: %T", v) + } + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestNumericPtrConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input any + options []string + expectedOutput types.AttributeValue + expectedError bool + }{ + { + input: aws.Uint(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Uint8(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Uint16(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Uint32(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Uint64(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Int(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Int8(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Int16(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Int32(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Int64(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: aws.Float32(123.456), + expectedOutput: &types.AttributeValueMemberN{Value: "123.456"}, + expectedError: false, + }, + { + input: aws.Float64(123.456), + expectedOutput: &types.AttributeValueMemberN{Value: "123.456"}, + expectedError: false, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var actualError error + var actualOutput any + + switch v := c.input.(type) { + case *uint: + cvt := NumericPtrConverter[uint]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *uint8: + cvt := NumericPtrConverter[uint8]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *uint16: + cvt := NumericPtrConverter[uint16]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *uint32: + cvt := NumericPtrConverter[uint32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *uint64: + cvt := NumericPtrConverter[uint64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *int: + cvt := NumericPtrConverter[int]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *int8: + cvt := NumericPtrConverter[int8]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *int16: + cvt := NumericPtrConverter[int16]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *int32: + cvt := NumericPtrConverter[int32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *int64: + cvt := NumericPtrConverter[int64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *float32: + cvt := NumericPtrConverter[float32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case *float64: + cvt := NumericPtrConverter[float64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + default: + t.Errorf("unsupported type: %T", v) + } + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/numeric_test.go b/feature/dynamodb/entitymanager/converters/numeric_test.go new file mode 100644 index 00000000000..fefba1a7c8e --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/numeric_test.go @@ -0,0 +1,347 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestNumericConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + options []string + expectedOutput any + expectedError bool + }{ + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: uint(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123"}, + expectedOutput: uint(18446744073709551493), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: uint(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: uint8(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "1234"}, + expectedOutput: uint8(210), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-1234"}, + expectedOutput: uint8(46), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: uint8(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: uint16(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456"}, + expectedOutput: uint16(57920), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: uint16(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: uint32(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "12345678901"}, + expectedOutput: uint32(3755744309), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: uint32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123"}, + expectedOutput: uint64(123), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: uint64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: uint64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123.10"}, + expectedOutput: float32(123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123.10"}, + expectedOutput: float32(-123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123,10"}, + expectedOutput: float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: float32(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123.10"}, + expectedOutput: float64(123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "-123.10"}, + expectedOutput: float64(-123.10), + expectedError: false, + }, + { + input: &types.AttributeValueMemberN{Value: "123456789012345678901234567890123"}, + expectedOutput: float64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "123,10"}, + expectedOutput: float64(0), + expectedError: true, + }, + { + input: &types.AttributeValueMemberN{Value: "test"}, + expectedOutput: float64(0), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var actualError error + var actualOutput any + + switch v := c.expectedOutput.(type) { + case uint: + cvt := NumericConverter[uint]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case uint8: + cvt := NumericConverter[uint8]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case uint16: + cvt := NumericConverter[uint16]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case uint32: + cvt := NumericConverter[uint32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case uint64: + cvt := NumericConverter[uint64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case int: + cvt := NumericConverter[int]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case int8: + cvt := NumericConverter[int8]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case int16: + cvt := NumericConverter[int16]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case int32: + cvt := NumericConverter[int32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case int64: + cvt := NumericConverter[int64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case float32: + cvt := NumericConverter[float32]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + case float64: + cvt := NumericConverter[float64]{} + actualOutput, actualError = cvt.FromAttributeValue(c.input, c.options) + default: + t.Errorf("unsupported type: %T", v) + } + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestNumericConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input any + options []string + expectedOutput types.AttributeValue + expectedError bool + }{ + { + input: uint(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: uint8(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: uint16(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: uint32(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: uint64(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: int(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: int8(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: int16(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: int32(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: int64(123), + expectedOutput: &types.AttributeValueMemberN{Value: "123"}, + expectedError: false, + }, + { + input: float32(123.456), + expectedOutput: &types.AttributeValueMemberN{Value: "123.456"}, + expectedError: false, + }, + { + input: float64(123.456), + expectedOutput: &types.AttributeValueMemberN{Value: "123.456"}, + expectedError: false, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var actualError error + var actualOutput any + + switch v := c.input.(type) { + case uint: + cvt := NumericConverter[uint]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case uint8: + cvt := NumericConverter[uint8]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case uint16: + cvt := NumericConverter[uint16]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case uint32: + cvt := NumericConverter[uint32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case uint64: + cvt := NumericConverter[uint64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case int: + cvt := NumericConverter[int]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case int8: + cvt := NumericConverter[int8]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case int16: + cvt := NumericConverter[int16]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case int32: + cvt := NumericConverter[int32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case int64: + cvt := NumericConverter[int64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case float32: + cvt := NumericConverter[float32]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + case float64: + cvt := NumericConverter[float64]{} + actualOutput, actualError = cvt.ToAttributeValue(v, c.options) + default: + t.Errorf("unsupported type: %T", v) + } + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/options.go b/feature/dynamodb/entitymanager/converters/options.go new file mode 100644 index 00000000000..ff5f0a128a8 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/options.go @@ -0,0 +1,18 @@ +package converters + +import ( + "fmt" + "strings" +) + +func getOpt(opts []string, name string) string { + p := fmt.Sprintf("%s=", name) + + for _, opt := range opts { + if strings.HasPrefix(opt, p) { + return opt[len(p):] + } + } + + return "" +} diff --git a/feature/dynamodb/entitymanager/converters/options_test.go b/feature/dynamodb/entitymanager/converters/options_test.go new file mode 100644 index 00000000000..31b60ad3651 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/options_test.go @@ -0,0 +1,46 @@ +package converters + +import ( + "strconv" + "testing" +) + +func TestGetOpt(t *testing.T) { + cases := []struct { + opts []string + name string + expected string + }{ + {}, + { + opts: []string{"test"}, + name: "test", + expected: "", + }, + { + opts: []string{"test="}, + name: "test", + expected: "", + }, + { + opts: []string{"test=test"}, + name: "test", + expected: "test", + }, + { + opts: []string{"TEST=test"}, + name: "test", + expected: "", + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + actual := getOpt(c.opts, c.name) + + if actual != c.expected { + t.Fatalf(`expected "%s", got "%s"`, c.expected, actual) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/registry.go b/feature/dynamodb/entitymanager/converters/registry.go new file mode 100644 index 00000000000..ebbc97b91e5 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/registry.go @@ -0,0 +1,149 @@ +package converters + +import ( + "time" +) + +// DefaultRegistry is a pre-populated Registry containing converters for +// primitive numeric and boolean types, their pointer forms, time.Time and +// *time.Time, byte slices ([]byte / []uint8), strings (string / *string), and +// a generic JSON converter ("json"). +// +// The keys in the registry map are the string representations returned by the +// internal getType helper (e.g. "int", "*int", "time.Time", "*time.Time"). +// +// DefaultRegistry is intended as a starting point. You may call Clone() to +// obtain an isolated copy and then Add or Remove converters without affecting +// the shared defaults. Direct mutation of DefaultRegistry (calling Add or +// Remove on it) is safe only if done during program initialization; concurrent +// writes without external synchronization are not supported. +var DefaultRegistry = &Registry{ + converters: map[string]AnyAttributeConverter{ + // numbers + "uint": &Wrapper[uint]{Impl: &NumericConverter[uint]{}}, + "uint8": &Wrapper[uint8]{Impl: &NumericConverter[uint8]{}}, + "uint16": &Wrapper[uint16]{Impl: &NumericConverter[uint16]{}}, + "uint32": &Wrapper[uint32]{Impl: &NumericConverter[uint32]{}}, + "uint64": &Wrapper[uint64]{Impl: &NumericConverter[uint64]{}}, + "int": &Wrapper[int]{Impl: &NumericConverter[int]{}}, + "int8": &Wrapper[int8]{Impl: &NumericConverter[int8]{}}, + "int16": &Wrapper[int16]{Impl: &NumericConverter[int16]{}}, + "int32": &Wrapper[int32]{Impl: &NumericConverter[int32]{}}, + "int64": &Wrapper[int64]{Impl: &NumericConverter[int64]{}}, + "float32": &Wrapper[float32]{Impl: &NumericConverter[float32]{}}, + "float64": &Wrapper[float64]{Impl: &NumericConverter[float64]{}}, + // numbers pointers + "*uint": &Wrapper[*uint]{Impl: &NumericPtrConverter[uint]{}}, + "*uint8": &Wrapper[*uint8]{Impl: &NumericPtrConverter[uint8]{}}, + "*uint16": &Wrapper[*uint16]{Impl: &NumericPtrConverter[uint16]{}}, + "*uint32": &Wrapper[*uint32]{Impl: &NumericPtrConverter[uint32]{}}, + "*uint64": &Wrapper[*uint64]{Impl: &NumericPtrConverter[uint64]{}}, + "*int": &Wrapper[*int]{Impl: &NumericPtrConverter[int]{}}, + "*int8": &Wrapper[*int8]{Impl: &NumericPtrConverter[int8]{}}, + "*int16": &Wrapper[*int16]{Impl: &NumericPtrConverter[int16]{}}, + "*int32": &Wrapper[*int32]{Impl: &NumericPtrConverter[int32]{}}, + "*int64": &Wrapper[*int64]{Impl: &NumericPtrConverter[int64]{}}, + "*float32": &Wrapper[*float32]{Impl: &NumericPtrConverter[float32]{}}, + "*float64": &Wrapper[*float64]{Impl: &NumericPtrConverter[float64]{}}, + // other + "bool": &Wrapper[bool]{Impl: &BoolConverter{}}, + "*bool": &Wrapper[*bool]{Impl: &BoolPtrConverter{}}, + "[]uint8": &Wrapper[[]uint8]{Impl: &ByteArrayConverter{}}, + "[]byte": &Wrapper[[]byte]{Impl: &ByteArrayConverter{}}, + "string": &Wrapper[string]{Impl: &StringConverter{}}, + "*string": &Wrapper[*string]{Impl: &StringPtrConverter{}}, + "time.Time": &Wrapper[time.Time]{Impl: &TimeConverter{}}, + "*time.Time": &Wrapper[*time.Time]{Impl: &TimePtrConverter{}}, + "json": JSONConverter{}, + }, +} + +// Registry maintains a mapping from a string type key to an AnyAttributeConverter. +// +// It is primarily used by the DynamoDB entity manager to look up conversion +// strategies for Go values when serializing to / deserializing from +// DynamoDB AttributeValue types. +// +// Concurrency: A Registry is safe for concurrent read access provided no +// goroutine is mutating it. Methods that modify internal state (Add, Remove) +// are not synchronized. To customize converters at runtime without racing, +// create an independent instance using NewRegistry or Clone and mutate that +// instance before sharing it for read-only use. +// +// Keys: Converter lookup keys are the canonical type names produced by the +// internal getType helper (e.g., "int", "*int", "time.Time"). When adding +// custom converters, ensure the key matches what getType(value) would return +// for values you expect to convert. +// +// Zero value: The zero value of Registry (var r Registry) functions correctly; +// maps are lazily allocated on first Add/Converter call. +type Registry struct { + //defaultConverter AnyAttributeConverter + converters map[string]AnyAttributeConverter +} + +// Clone creates a deep copy of the Registry's converter mapping. Converters +// themselves are not copied (the underlying converter implementations are +// assumed to be stateless or safely shareable); only the map entries are +// duplicated. The returned Registry can be mutated independently of the +// original. +func (cr *Registry) Clone() *Registry { + r := &Registry{ + converters: make(map[string]AnyAttributeConverter), + } + + for k, v := range cr.converters { + r.converters[k] = v + } + + return r +} + +// Add registers (or replaces) a converter under the provided name and returns +// the Registry for fluent chaining. If the internal map has not yet been +// allocated it will be created. Replacing an existing converter is a silent +// overwrite. +func (cr *Registry) Add(name string, converter AnyAttributeConverter) *Registry { + if cr.converters == nil { + cr.converters = make(map[string]AnyAttributeConverter) + } + + cr.converters[name] = converter + + return cr +} + +// Remove deletes a converter by name and returns the Registry for fluent +// chaining. Removing a non-existent key is a no-op. +func (cr *Registry) Remove(name string) *Registry { + delete(cr.converters, name) + + return cr +} + +// Converter returns the converter registered for the given name. If the map +// has not yet been allocated it will be created (resulting in a nil return +// unless an entry was previously added). If no converter is found, nil is +// returned. +func (cr *Registry) Converter(name string) AnyAttributeConverter { + if cr.converters == nil { + cr.converters = make(map[string]AnyAttributeConverter) + } + + return cr.converters[name] +} + +// ConverterFor performs a lookup using the canonical string key derived from +// the dynamic type of x (via internal getType). If x is nil or its type has +// no registered converter, nil is returned. +func (cr *Registry) ConverterFor(x any) AnyAttributeConverter { + return cr.Converter(getType(x)) +} + +// NewRegistry constructs an empty Registry with an allocated converter map. +// Use Add to populate converters or copy from DefaultRegistry via Clone. +func NewRegistry() *Registry { + return &Registry{ + converters: make(map[string]AnyAttributeConverter), + } +} diff --git a/feature/dynamodb/entitymanager/converters/registry_test.go b/feature/dynamodb/entitymanager/converters/registry_test.go new file mode 100644 index 00000000000..2a47f02b7eb --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/registry_test.go @@ -0,0 +1,93 @@ +package converters + +import ( + "strconv" + "testing" +) + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + if r == nil { + t.Fatal("NewRegistry returned nil") + } + if r.converters == nil { + t.Fatal("NewRegistry did not initialize converters map") + } +} + +func TestRegistry_Clone(t *testing.T) { + r := DefaultRegistry.Clone() + if r == nil { + t.Fatal("Clone returned nil") + } + if r.converters == nil { + t.Fatal("Clone did not initialize converters map") + } + if len(r.converters) != len(DefaultRegistry.converters) { + t.Errorf("Clone did not copy all converters: got %d, want %d", len(r.converters), len(DefaultRegistry.converters)) + } +} + +func TestRegistry_Add(t *testing.T) { + r := &Registry{} + initial := len(r.converters) + r.Add("mock", &mockConverter{}) + if len(r.converters) != initial+1 { + t.Errorf("Add did not increase converter count: got %d, want %d", len(r.converters), initial+1) + } + if r.Converter("mock") == nil { + t.Error("Add did not register converter under 'mock'") + } +} + +func TestRegistry_Remove(t *testing.T) { + r := DefaultRegistry.Clone() + initial := len(r.converters) + r.Remove("json") + if len(r.converters) != initial-1 { + t.Errorf("Remove did not decrease converter count: got %d, want %d", len(r.converters), initial-1) + } + if r.Converter("json") != nil { + t.Error("Remove did not remove 'json' converter") + } +} + +func TestRegistry_Converter(t *testing.T) { + r := DefaultRegistry.Clone() + cases := []struct { + name string + ok bool + }{ + {"404", false}, + {"json", true}, + } + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + conv := r.Converter(c.name) + if (conv != nil) != c.ok { + t.Errorf("Converter(%q) presence = %v, want %v", c.name, conv != nil, c.ok) + } + }) + } +} + +func TestRegistry_ConverterFor(t *testing.T) { + r := DefaultRegistry.Clone() + + t.Run("known type", func(t *testing.T) { + var s string + conv := r.ConverterFor(s) + if conv == nil { + t.Errorf("expected converter for string, got nil") + } + }) + + t.Run("unknown type", func(t *testing.T) { + type custom struct{} + var c custom + conv := r.ConverterFor(c) + if conv != nil { + t.Errorf("expected no converter for custom type, got one") + } + }) +} diff --git a/feature/dynamodb/entitymanager/converters/string.go b/feature/dynamodb/entitymanager/converters/string.go new file mode 100644 index 00000000000..00fa4bc819c --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/string.go @@ -0,0 +1,49 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[string] = (*StringConverter)(nil) + +// StringConverter implements AttributeConverter for the Go string type. +// +// DynamoDB stores string values in AttributeValueMemberS. This converter: +// - Validates that the provided AttributeValue is of type *types.AttributeValueMemberS +// - Returns an error if a nil AttributeValueMemberS is encountered +// - Performs no special handling for empty strings ("" is a valid value) +// +// Tag options: Currently ignored. The parameter is accepted to keep a uniform +// signature with other converters that may use tag-derived options. +// +// Usage: +// +// conv := StringConverter{} +// av, _ := conv.ToAttributeValue("hello", nil) +// s, _ := conv.FromAttributeValue(av, nil) +// +// The converter is stateless and safe for concurrent use. +type StringConverter struct { +} + +// FromAttributeValue converts a DynamoDB string (S) AttributeValue to a Go string. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (n StringConverter) FromAttributeValue(v types.AttributeValue, _ []string) (string, error) { + switch av := v.(type) { + case *types.AttributeValueMemberS: + if av == nil { + return "", ErrNilValue + } + + return av.Value, nil + default: + return "", unsupportedType(v, (*types.AttributeValueMemberS)(nil)) + } +} + +// ToAttributeValue converts a Go string to a DynamoDB string (S) AttributeValue. +func (n StringConverter) ToAttributeValue(v string, _ []string) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{ + Value: v, + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/string_ptr.go b/feature/dynamodb/entitymanager/converters/string_ptr.go new file mode 100644 index 00000000000..5872f7ba913 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/string_ptr.go @@ -0,0 +1,55 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// Compile-time assertion that StringPtrConverter satisfies AttributeConverter[*string]. +var _ AttributeConverter[*string] = (*StringPtrConverter)(nil) + +// StringPtrConverter implements AttributeConverter for the Go *string type. +// +// DynamoDB stores string values in AttributeValueMemberS. This converter: +// - Validates that the provided AttributeValue is of type *types.AttributeValueMemberS +// - Returns ErrNilValue if the concrete *types.AttributeValueMemberS is nil +// - Distinguishes between a nil *string (error when marshaling) and a pointer to an empty string (valid value) +// +// Tag options: Currently ignored. The parameter is accepted to keep a uniform +// signature with other converters that may use tag-derived options. +// +// Usage: +// +// conv := StringPtrConverter{} +// val := "hello" +// av, _ := conv.ToAttributeValue(&val, nil) +// s, _ := conv.FromAttributeValue(av, nil) +// +// The converter is stateless and safe for concurrent use. +type StringPtrConverter struct{} + +// FromAttributeValue converts a DynamoDB string (S) AttributeValue to a Go *string. +// Returns ErrNilValue for nil pointers, or unsupportedType for unsupported AttributeValue types. +func (n StringPtrConverter) FromAttributeValue(v types.AttributeValue, _ []string) (*string, error) { + switch av := v.(type) { + case *types.AttributeValueMemberS: + if av == nil { + return nil, ErrNilValue + } + + return &av.Value, nil + default: + return nil, unsupportedType(v, (*types.AttributeValueMemberS)(nil)) + } +} + +// ToAttributeValue converts a Go *string to a DynamoDB string (S) AttributeValue. +// Returns ErrNilValue if the input pointer is nil. +func (n StringPtrConverter) ToAttributeValue(v *string, _ []string) (types.AttributeValue, error) { + if v == nil { + return nil, ErrNilValue + } + + return &types.AttributeValueMemberS{ + Value: *v, + }, nil +} diff --git a/feature/dynamodb/entitymanager/converters/string_ptr_test.go b/feature/dynamodb/entitymanager/converters/string_ptr_test.go new file mode 100644 index 00000000000..26285cded1d --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/string_ptr_test.go @@ -0,0 +1,115 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestStringPtrConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + options []string + expectedOutput any + expectedError bool + }{ + { + input: &types.AttributeValueMemberS{}, + expectedOutput: aws.String(""), + }, + { + input: &types.AttributeValueMemberS{ + Value: "test", + }, + expectedOutput: aws.String("test"), + }, + { + input: &types.AttributeValueMemberB{}, + expectedError: true, + }, + { + input: (*types.AttributeValueMemberB)(nil), + expectedError: true, + }, + { + input: (*types.AttributeValueMemberS)(nil), + expectedError: true, + }, + { + input: nil, + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cvt := &StringPtrConverter{} + actualOutput, actualError := cvt.FromAttributeValue(c.input, c.options) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestStringPtrConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input *string + options []string + expectedOutput types.AttributeValue + expectedError bool + }{ + { + input: aws.String(""), + expectedOutput: &types.AttributeValueMemberS{}, + }, + { + input: aws.String("test"), + expectedOutput: &types.AttributeValueMemberS{ + Value: "test", + }, + }, + { + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cvt := &StringPtrConverter{} + actualOutput, actualError := cvt.ToAttributeValue(c.input, c.options) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/string_test.go b/feature/dynamodb/entitymanager/converters/string_test.go new file mode 100644 index 00000000000..830c0c8a5e9 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/string_test.go @@ -0,0 +1,111 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestStringConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + options []string + expectedOutput any + expectedError bool + }{ + { + input: &types.AttributeValueMemberS{}, + expectedOutput: "", + }, + { + input: &types.AttributeValueMemberS{ + Value: "test", + }, + expectedOutput: "test", + }, + { + input: &types.AttributeValueMemberB{}, + expectedError: true, + }, + { + input: (*types.AttributeValueMemberB)(nil), + expectedError: true, + }, + { + input: (*types.AttributeValueMemberS)(nil), + expectedError: true, + }, + { + input: nil, + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cvt := &StringConverter{} + actualOutput, actualError := cvt.FromAttributeValue(c.input, c.options) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestStringConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input string + options []string + expectedOutput types.AttributeValue + expectedError bool + }{ + { + input: "", + expectedOutput: &types.AttributeValueMemberS{}, + }, + { + input: "test", + expectedOutput: &types.AttributeValueMemberS{ + Value: "test", + }, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + cvt := &StringConverter{} + actualOutput, actualError := cvt.ToAttributeValue(c.input, c.options) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/time.go b/feature/dynamodb/entitymanager/converters/time.go new file mode 100644 index 00000000000..7d309cb6dca --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/time.go @@ -0,0 +1,202 @@ +package converters + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[time.Time] = (*TimeConverter)(nil) + +// taken from go src/time/format.go +var knowTimeFormats = []string{ + time.Layout, //= "01/02 03:04:05PM '06 -0700" // The reference time, in numerical order. + time.ANSIC, //= "Mon Jan _2 15:04:05 2006" + time.UnixDate, //= "Mon Jan _2 15:04:05 MST 2006" + time.RubyDate, //= "Mon Jan 02 15:04:05 -0700 2006" + time.RFC822, //= "02 Jan 06 15:04 MST" + time.RFC822Z, //= "02 Jan 06 15:04 -0700" // RFC822 with numeric zone + time.RFC850, //= "Monday, 02-Jan-06 15:04:05 MST" + time.RFC1123, //= "Mon, 02 Jan 2006 15:04:05 MST" + time.RFC1123Z, //= "Mon, 02 Jan 2006 15:04:05 -0700" // RFC1123 with numeric zone + time.RFC3339, //= "2006-01-02T15:04:05Z07:00" + time.RFC3339Nano, //= "2006-01-02T15:04:05.999999999Z07:00" + time.Kitchen, //= "3:04PM" + time.Stamp, //= "Jan _2 15:04:05" + time.StampMilli, //= "Jan _2 15:04:05.000" + time.StampMicro, //= "Jan _2 15:04:05.000000" + time.StampNano, //= "Jan _2 15:04:05.000000000" + time.DateTime, //= "2006-01-02 15:04:05" + time.DateOnly, //= "2006-01-02" + time.TimeOnly, //= "15:04:05" +} + +// defaultTimeFormat is the layout used when encoding time values as strings +// if no explicit "format" option is supplied. +var defaultTimeFormat = time.RFC3339Nano + +// TimeConverter implements AttributeConverter for time.Time values. +// +// Supported DynamoDB representations: +// - String (AttributeValueMemberS) using a layout provided via the "format" option or any of knowTimeFormats. +// - Number (AttributeValueMemberN) encoding seconds and optional fractional nanoseconds as "[.]". +// +// Options (case-sensitive keys): +// +// format: Custom time.Parse / time.Format layout string (Go reference layout). Overrides fallback list when decoding. +// TZ: IANA timezone name (e.g., "UTC", "America/New_York") applied after successful parse when decoding. +// as: When encoding, chooses representation: "string" (default) or "number". +// +// Zero value handling: Decoding to a zero time returns the zero value (not an error). There is +// no nil sentinel for non-pointer time.Time values. +// +// Fractional seconds when encoding as number are trimmed of trailing zeros. +// The converter is stateless and safe for concurrent use. +type TimeConverter struct{} + +// FromAttributeValue converts a DynamoDB AttributeValue into a time.Time using either +// string or numeric representations. +// +// Decoding logic: +// +// String: Attempts parse with provided "format" opt if present, else tries knowTimeFormats in order. +// On parse failure across all formats returns an error listing attempted formats. +// Number: Splits on '.', first part seconds, second (optional) part nanoseconds (truncated to 9 digits). +// Rejects more than one '.' (i.e., len(parts) > 2). +// +// Options: +// +// format: Single layout used instead of fallback list (string input only). +// TZ: IANA timezone applied post-parse; errors if unknown. +// +// Error cases: +// - Nil underlying AttributeValueMemberS/N -> ErrNilValue +// - Unsupported AttributeValue type -> unsupportedType error +// - Invalid numeric format or parse errors -> descriptive errors +// - Unknown timezone -> error +// +// Returns zero time if parsed value is zero. Zero time is treated as absence but +// still returned (not an error). Consumers may check t.IsZero(). +func (tc TimeConverter) FromAttributeValue(v types.AttributeValue, opts []string) (time.Time, error) { + t := time.Time{} + + switch av := v.(type) { + case *types.AttributeValueMemberS: + // when calling with v = (*types.AttributeValueMemberS)(nil) then v == nil is false + // e.g. tc.FromAttributeValue((*types.AttributeValueMemberS)(nil), nil) -> panics + if av == nil { + return time.Time{}, ErrNilValue + } + + format := getOpt(opts, "format") + var formats []string + if format != "" { + formats = []string{format} + } else { + formats = knowTimeFormats + } + + var err error + for _, f := range formats { + t, err = time.Parse(f, av.Value) + if err == nil { + break + } + } + + // err will be populated only if all time.Parse() attempts returned an error + if err != nil { + return time.Time{}, fmt.Errorf("unable to process time %s with format(s): %v", av.Value, formats) + } + case *types.AttributeValueMemberN: + // when calling with v = (*types.AttributeValueMemberN)(nil) then v == nil is false + // e.g. tc.FromAttributeValue((*types.AttributeValueMemberN)(nil), nil) -> panics + if av == nil { + return time.Time{}, ErrNilValue + } + + parts := strings.Split(av.Value, ".") + // format is "000" or "000.000", anything else is an issue + if len(parts) > 2 { + return time.Time{}, fmt.Errorf("unsupported format for number inside of types.AttributeValueMemberN: %v", av.Value) + } + + var err error + ps := make([]int64, 2) + + for i := range parts { + // microseconds can be at most 9 chars long, otherwise they overflow into the seconds part + if i == 1 && len(parts[i]) > 9 { + parts[i] = parts[i][0:9] + } + ps[i], err = strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return time.Time{}, fmt.Errorf("error parsing int: %v", parts[i]) + } + } + + t = time.Unix(ps[0], ps[1]) + default: + return time.Time{}, unsupportedType( + v, + (*types.AttributeValueMemberS)(nil), + (*types.AttributeValueMemberN)(nil), + ) + } + + if t.IsZero() { + return time.Time{}, nil + } + + if tz := getOpt(opts, "TZ"); tz != "" { + loc, err := time.LoadLocation(tz) + if err != nil { + return time.Time{}, fmt.Errorf(`error loading timezone "%s" data: %v`, tz, err) + } + t = t.In(loc) + } + + return t, nil +} + +// ToAttributeValue converts a time.Time into a DynamoDB AttributeValue. +// +// Options: +// +// as: "number" -> encodes seconds[.nanos] in AttributeValueMemberN, trimming trailing zeros in nanos. +// "string" or empty -> encodes formatted string using layout from "format" opt or defaultTimeFormat. +// format: Go layout string used when encoding as string (ignored for number encoding). Falls back to defaultTimeFormat. +// +// Errors: +// - Unknown "as" value -> error listing expected values +// +// Returned AttributeValue will be *types.AttributeValueMemberS or *types.AttributeValueMemberN depending on representation. +func (tc TimeConverter) ToAttributeValue(v time.Time, opts []string) (types.AttributeValue, error) { + as := getOpt(opts, "as") + format := getOpt(opts, "format") + if format == "" { + format = defaultTimeFormat + } + + switch as { + case "number": + parts := []string{ + fmt.Sprintf("%v", v.Unix()), + } + if v.Nanosecond() != 0 { + parts = append(parts, strings.TrimRight(fmt.Sprintf("%v", v.Nanosecond()), "0")) + } + return &types.AttributeValueMemberN{ + Value: strings.Join(parts, "."), + }, nil + case "string", "": + return &types.AttributeValueMemberS{ + Value: v.Format(format), + }, nil + default: + return nil, fmt.Errorf(`unknown value for time format: expected "", "string" or "number", got "%v"`, as) + } +} diff --git a/feature/dynamodb/entitymanager/converters/time_ptr.go b/feature/dynamodb/entitymanager/converters/time_ptr.go new file mode 100644 index 00000000000..6d2e43eeb85 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/time_ptr.go @@ -0,0 +1,180 @@ +package converters + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[*time.Time] = (*TimePtrConverter)(nil) + +// TimePtrConverter implements AttributeConverter for *time.Time values. +// +// Supported DynamoDB representations: +// - String (AttributeValueMemberS) using a layout provided via the "format" option or any of knowTimeFormats. +// - Number (AttributeValueMemberN) encoding seconds and optional fractional nanoseconds as "[.]". +// +// Options (case-sensitive keys): +// +// format: Custom time.Parse / time.Format layout string (Go reference layout). Overrides fallback list when decoding. +// TZ: IANA timezone name (e.g., "UTC", "America/New_York") applied after successful parse when decoding. +// as: When encoding, chooses representation: "string" (default) or "number". +// +// Nil handling: +// - Encoding a nil *time.Time returns ErrNilValue. +// - Decoding a zero time (t.IsZero()) returns (nil, nil) signaling absence. +// +// Fractional seconds when encoding as number are trimmed of trailing zeros. +// The converter is stateless and safe for concurrent use. +type TimePtrConverter struct{} + +// FromAttributeValue converts a DynamoDB AttributeValue into a *time.Time using +// either string or numeric representations. +// +// Decoding logic: +// +// String: Attempts parse with provided "format" opt if present, else tries knowTimeFormats in order. +// On parse failure across all formats returns an error listing attempted formats. +// Number: Splits on '.', first part seconds, second (optional) part nanoseconds (truncated to 9 digits). +// Rejects more than one '.' (i.e., len(parts) > 2). +// +// Options: +// +// format: Single layout used instead of fallback list (string input only). +// TZ: IANA timezone applied post-parse; errors if unknown. +// +// Error cases: +// - Nil underlying AttributeValueMemberS/N -> ErrNilValue +// - Unsupported AttributeValue type -> unsupportedType error +// - Invalid numeric format or parse errors -> descriptive errors +// - Unknown timezone -> error +// +// Returns (nil, nil) if parsed time is the zero value. +func (tc TimePtrConverter) FromAttributeValue(v types.AttributeValue, opts []string) (*time.Time, error) { + t := time.Time{} + + switch av := v.(type) { + case *types.AttributeValueMemberS: + // when calling with v = (*types.AttributeValueMemberS)(nil) then v == nil is false + // e.g. tc.FromAttributeValue((*types.AttributeValueMemberS)(nil), nil) -> panics + if av == nil { + return nil, ErrNilValue + } + + format := getOpt(opts, "format") + var formats []string + if format != "" { + formats = []string{format} + } else { + formats = knowTimeFormats + } + + var err error + for _, f := range formats { + t, err = time.Parse(f, av.Value) + if err == nil { + break + } + } + + // err will be populated only if all time.Parse() attempts returned an error + if err != nil { + return nil, fmt.Errorf("unable to process time %s with format(s): %v", av.Value, formats) + } + case *types.AttributeValueMemberN: + // when calling with v = (*types.AttributeValueMemberN)(nil) then v == nil is false + // e.g. tc.FromAttributeValue((*types.AttributeValueMemberN)(nil), nil) -> panics + if av == nil { + return nil, ErrNilValue + } + + parts := strings.Split(av.Value, ".") + // format is "000" or "000.000", anything else is an issue + if len(parts) > 2 { + return nil, fmt.Errorf("unsupported format for number inside of types.AttributeValueMemberN: %v", av.Value) + } + + var err error + ps := make([]int64, 2) + + for i := range parts { + // microseconds can be at most 9 chars long, otherwise they overflow into the seconds part + if i == 1 && len(parts[i]) > 9 { + parts[i] = parts[i][0:9] + } + ps[i], err = strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return nil, fmt.Errorf("error parsing int: %v", parts[i]) + } + } + + t = time.Unix(ps[0], ps[1]) + default: + return nil, unsupportedType( + v, + (*types.AttributeValueMemberS)(nil), + (*types.AttributeValueMemberN)(nil), + ) + } + + if t.IsZero() { + return nil, nil + } + + if tz := getOpt(opts, "TZ"); tz != "" { + loc, err := time.LoadLocation(tz) + if err != nil { + return nil, fmt.Errorf(`error loading timezone "%s" data: %v`, tz, err) + } + t = t.In(loc) + } + + return &t, nil +} + +// ToAttributeValue converts a *time.Time into a DynamoDB AttributeValue. +// +// Options: +// +// as: "number" -> encodes seconds[.nanos] in AttributeValueMemberN, trimming trailing zeros in nanos. +// "string" or empty -> encodes formatted string using layout from "format" opt or defaultTimeFormat. +// format: Go layout string used when encoding as string (ignored for number encoding). Falls back to defaultTimeFormat. +// +// Errors: +// - Nil input -> ErrNilValue +// - Unknown "as" value -> error listing expected values +// +// Returned AttributeValue will be *types.AttributeValueMemberS or *types.AttributeValueMemberN depending on representation. +func (tc TimePtrConverter) ToAttributeValue(v *time.Time, opts []string) (types.AttributeValue, error) { + if v == nil { + return nil, ErrNilValue + } + + as := getOpt(opts, "as") + format := getOpt(opts, "format") + if format == "" { + format = defaultTimeFormat + } + + switch as { + case "number": + parts := []string{ + fmt.Sprintf("%v", v.Unix()), + } + if v.Nanosecond() != 0 { + parts = append(parts, strings.TrimRight(fmt.Sprintf("%v", v.Nanosecond()), "0")) + } + return &types.AttributeValueMemberN{ + Value: strings.Join(parts, "."), + }, nil + case "string", "": + return &types.AttributeValueMemberS{ + Value: v.Format(format), + }, nil + default: + return nil, fmt.Errorf(`unknown time format: expected "", "string" or "number", got "%v"`, as) + } +} diff --git a/feature/dynamodb/entitymanager/converters/time_ptr_test.go b/feature/dynamodb/entitymanager/converters/time_ptr_test.go new file mode 100644 index 00000000000..aaa84a1809a --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/time_ptr_test.go @@ -0,0 +1,137 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestTimePtrConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + opts []string + expectedOutput any + expectedError bool + }{ + {input: &types.AttributeValueMemberN{Value: "1136214245"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + {input: &types.AttributeValueMemberN{Value: "1136214245.113621424"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 113621424).In(time.UTC)), expectedError: false}, + {input: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM #06 +0000"}, opts: []string{"TZ=UTC", "format=01/02 03:04:05PM #06 -0700"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.Layout + {input: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM '06 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.ANSIC + {input: &types.AttributeValueMemberS{Value: "Mon Jan 2 15:04:05 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.UnixDate + {input: &types.AttributeValueMemberS{Value: "Mon Jan 02 15:04:05 UTC 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RubyDate + {input: &types.AttributeValueMemberS{Value: "Mon Jan 02 15:04:05 +0000 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RFC822 + {input: &types.AttributeValueMemberS{Value: "02 Jan 06 15:04 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214240, 0).In(time.UTC)), expectedError: false}, + //time.RFC822Z + {input: &types.AttributeValueMemberS{Value: "02 Jan 06 15:04 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214240, 0).In(time.UTC)), expectedError: false}, + //time.RFC850 + {input: &types.AttributeValueMemberS{Value: "Monday, 02-Jan-06 15:04:05 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RFC1123 + {input: &types.AttributeValueMemberS{Value: "Mon, 02 Jan 2006 15:04:05 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RFC1123Z + {input: &types.AttributeValueMemberS{Value: "Mon, 02 Jan 2006 15:04:05 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RFC3339 + {input: &types.AttributeValueMemberS{Value: "2006-01-02T15:04:05+00:00"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.RFC3339Nano + {input: &types.AttributeValueMemberS{Value: "2006-01-02T15:04:05.999999999+00:00"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 999999999).In(time.UTC)), expectedError: false}, + //time.Kitchen + {input: &types.AttributeValueMemberS{Value: "3:04PM"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167164960, 0).In(time.UTC)), expectedError: false}, + //time.Stamp + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167078555, 0).In(time.UTC)), expectedError: false}, + //time.StampMilli + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167078555, 0).In(time.UTC)), expectedError: false}, + //time.StampMicro + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167078555, 0).In(time.UTC)), expectedError: false}, + //time.StampNano + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000000000"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167078555, 0).In(time.UTC)), expectedError: false}, + //time.DateTime + {input: &types.AttributeValueMemberS{Value: "2006-01-02 15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), expectedError: false}, + //time.DateOnly + {input: &types.AttributeValueMemberS{Value: "2006-01-02"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(1136160000, 0).In(time.UTC)), expectedError: false}, + //time.TimeOnly + {input: &types.AttributeValueMemberS{Value: "15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: aws.Time(time.Unix(-62167164955, 0).In(time.UTC)), expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (types.AttributeValue)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberN)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberS)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberBOOL)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: &types.AttributeValueMemberBOOL{Value: true}, opts: nil, expectedOutput: nil, expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := TimePtrConverter{} + + actualOutput, actualError := tc.FromAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestTimePtrConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input *time.Time + opts []string + expectedOutput types.AttributeValue + expectedError bool + }{ + {input: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), opts: []string{"TZ=UTC"}, expectedOutput: &types.AttributeValueMemberS{Value: time.Unix(1136214245, 0).In(time.UTC).Format(time.RFC3339)}, expectedError: false}, + {input: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), opts: []string{"TZ=UTC", "as=number"}, expectedOutput: &types.AttributeValueMemberN{Value: "1136214245"}, expectedError: false}, + {input: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), opts: []string{"TZ=UTC", "as=string"}, expectedOutput: &types.AttributeValueMemberS{Value: time.Unix(1136214245, 0).In(time.UTC).Format(time.RFC3339)}, expectedError: false}, + {input: aws.Time(time.Unix(1136214245, 113621424).In(time.UTC)), opts: []string{"TZ=UTC", "as=number"}, expectedOutput: &types.AttributeValueMemberN{Value: "1136214245.113621424"}, expectedError: false}, + {input: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), opts: []string{"TZ=UTC", "format=01/02 03:04:05PM #06 -0700"}, expectedOutput: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM #06 +0000"}, expectedError: false}, + {input: aws.Time(time.Unix(1136214245, 0).In(time.UTC)), opts: []string{"TZ=UTC", "format=01/02 03:04:05PM 06 -0700"}, expectedOutput: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM 06 +0000"}, expectedError: false}, + {input: &time.Time{}, opts: nil, expectedOutput: &types.AttributeValueMemberS{Value: "0001-01-01T00:00:00Z"}, expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*time.Time)(nil), opts: nil, expectedOutput: (types.AttributeValue)(nil), expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := TimePtrConverter{} + + actualOutput, actualError := tc.ToAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/time_test.go b/feature/dynamodb/entitymanager/converters/time_test.go new file mode 100644 index 00000000000..422d2b277fd --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/time_test.go @@ -0,0 +1,136 @@ +package converters + +import ( + "reflect" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestTimeConverter_FromAttributeValue(t *testing.T) { + cases := []struct { + input types.AttributeValue + opts []string + expectedOutput any + expectedError bool + }{ + {input: &types.AttributeValueMemberN{Value: "1136214245"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + {input: &types.AttributeValueMemberN{Value: "1136214245.113621424"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 113621424).In(time.UTC), expectedError: false}, + {input: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM #06 +0000"}, opts: []string{"TZ=UTC", "format=01/02 03:04:05PM #06 -0700"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.Layout + {input: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM '06 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.ANSIC + {input: &types.AttributeValueMemberS{Value: "Mon Jan 2 15:04:05 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.UnixDate + {input: &types.AttributeValueMemberS{Value: "Mon Jan 02 15:04:05 UTC 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RubyDate + {input: &types.AttributeValueMemberS{Value: "Mon Jan 02 15:04:05 +0000 2006"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RFC822 + {input: &types.AttributeValueMemberS{Value: "02 Jan 06 15:04 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214240, 0).In(time.UTC), expectedError: false}, + //time.RFC822Z + {input: &types.AttributeValueMemberS{Value: "02 Jan 06 15:04 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214240, 0).In(time.UTC), expectedError: false}, + //time.RFC850 + {input: &types.AttributeValueMemberS{Value: "Monday, 02-Jan-06 15:04:05 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RFC1123 + {input: &types.AttributeValueMemberS{Value: "Mon, 02 Jan 2006 15:04:05 UTC"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RFC1123Z + {input: &types.AttributeValueMemberS{Value: "Mon, 02 Jan 2006 15:04:05 +0000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RFC3339 + {input: &types.AttributeValueMemberS{Value: "2006-01-02T15:04:05+00:00"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.RFC3339Nano + {input: &types.AttributeValueMemberS{Value: "2006-01-02T15:04:05.999999999+00:00"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 999999999).In(time.UTC), expectedError: false}, + //time.Kitchen + {input: &types.AttributeValueMemberS{Value: "3:04PM"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167164960, 0).In(time.UTC), expectedError: false}, + //time.Stamp + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167078555, 0).In(time.UTC), expectedError: false}, + //time.StampMilli + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167078555, 0).In(time.UTC), expectedError: false}, + //time.StampMicro + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167078555, 0).In(time.UTC), expectedError: false}, + //time.StampNano + {input: &types.AttributeValueMemberS{Value: "Jan 02 15:04:05.000000000"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167078555, 0).In(time.UTC), expectedError: false}, + //time.DateTime + {input: &types.AttributeValueMemberS{Value: "2006-01-02 15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136214245, 0).In(time.UTC), expectedError: false}, + //time.DateOnly + {input: &types.AttributeValueMemberS{Value: "2006-01-02"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(1136160000, 0).In(time.UTC), expectedError: false}, + //time.TimeOnly + {input: &types.AttributeValueMemberS{Value: "15:04:05"}, opts: []string{"TZ=UTC"}, expectedOutput: time.Unix(-62167164955, 0).In(time.UTC), expectedError: false}, + // errors + {input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + {input: (types.AttributeValue)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberN)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberS)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: (*types.AttributeValueMemberBOOL)(nil), opts: nil, expectedOutput: nil, expectedError: true}, + {input: &types.AttributeValueMemberBOOL{Value: true}, opts: nil, expectedOutput: nil, expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := TimeConverter{} + + actualOutput, actualError := tc.FromAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} + +func TestTimeConverter_ToAttributeValue(t *testing.T) { + cases := []struct { + input time.Time + opts []string + expectedOutput types.AttributeValue + expectedError bool + }{ + {input: time.Unix(1136214245, 0).In(time.UTC), opts: []string{"TZ=UTC"}, expectedOutput: &types.AttributeValueMemberS{Value: time.Unix(1136214245, 0).In(time.UTC).Format(time.RFC3339)}, expectedError: false}, + {input: time.Unix(1136214245, 0).In(time.UTC), opts: []string{"TZ=UTC", "as=number"}, expectedOutput: &types.AttributeValueMemberN{Value: "1136214245"}, expectedError: false}, + {input: time.Unix(1136214245, 0).In(time.UTC), opts: []string{"TZ=UTC", "as=string"}, expectedOutput: &types.AttributeValueMemberS{Value: time.Unix(1136214245, 0).In(time.UTC).Format(time.RFC3339)}, expectedError: false}, + {input: time.Unix(1136214245, 113621424).In(time.UTC), opts: []string{"TZ=UTC", "as=number"}, expectedOutput: &types.AttributeValueMemberN{Value: "1136214245.113621424"}, expectedError: false}, + {input: time.Unix(1136214245, 0).In(time.UTC), opts: []string{"TZ=UTC", "format=01/02 03:04:05PM #06 -0700"}, expectedOutput: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM #06 +0000"}, expectedError: false}, + {input: time.Unix(1136214245, 0).In(time.UTC), opts: []string{"TZ=UTC", "format=01/02 03:04:05PM 06 -0700"}, expectedOutput: &types.AttributeValueMemberS{Value: "01/02 03:04:05PM 06 +0000"}, expectedError: false}, + {input: time.Time{}, opts: nil, expectedOutput: &types.AttributeValueMemberS{Value: "0001-01-01T00:00:00Z"}, expectedError: false}, + // errors + //{input: nil, opts: nil, expectedOutput: nil, expectedError: true}, + //{input: (*time.Time)(nil), opts: nil, expectedOutput: (types.AttributeValue)(nil), expectedError: true}, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + tc := TimeConverter{} + + actualOutput, actualError := tc.ToAttributeValue(c.input, c.opts) + + if actualError == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if actualError != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", actualError) + } + + if actualError != nil && c.expectedError { + return + } + + if !reflect.DeepEqual(c.expectedOutput, actualOutput) { + t.Fatalf("%#+v != %#+v", c.expectedOutput, actualOutput) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/converters/types.go b/feature/dynamodb/entitymanager/converters/types.go new file mode 100644 index 00000000000..f7982475703 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/types.go @@ -0,0 +1,24 @@ +package converters + +import ( + "reflect" +) + +// getType returns the canonical string form of the dynamic type of x, +// as produced by reflect.TypeOf(x).String(). This value is used as the +// key in converter registries (e.g., "int", "*time.Time", "[]byte"). +// It is faster than fmt.Sprintf("%T", x) (avoids formatting machinery) +// and slightly faster than obtaining a generic type via reflect.TypeFor[T]. +// +// Benchmark (indicative only; values vary by Go version, hardware, flags): +// +// fmt.Sprintf(\"%T\", x) ~ slower +// reflect.TypeFor[T]() ~ medium +// reflect.TypeOf(x).String() ~ fastest in our measurements +// +// Caveat: Passing a nil interface (x == nil) will cause a panic because +// reflect.TypeOf(nil) == nil and the subsequent call to .String() dereferences nil. +// Callers must ensure x is non-nil. +func getType(x any) string { + return reflect.TypeOf(x).String() +} diff --git a/feature/dynamodb/entitymanager/converters/types_test.go b/feature/dynamodb/entitymanager/converters/types_test.go new file mode 100644 index 00000000000..bcd37d05e79 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/types_test.go @@ -0,0 +1,138 @@ +package converters + +import ( + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" +) + +func TestGetType(t *testing.T) { + cases := []struct { + input any + expected string + }{ + { + input: uint(0), + expected: "uint", + }, + { + input: uint8(0), + expected: "uint8", + }, + { + input: uint16(0), + expected: "uint16", + }, + { + input: uint32(0), + expected: "uint32", + }, + { + input: uint64(0), + expected: "uint64", + }, + { + input: int(0), + expected: "int", + }, + { + input: int8(0), + expected: "int8", + }, + { + input: int16(0), + expected: "int16", + }, + { + input: int32(0), + expected: "int32", + }, + { + input: int64(0), + expected: "int64", + }, + { + input: float32(0), + expected: "float32", + }, + { + input: float64(0), + expected: "float64", + }, + { + input: aws.Uint(uint(0)), + expected: "*uint", + }, + { + input: aws.Uint8(uint8(0)), + expected: "*uint8", + }, + { + input: aws.Uint16(uint16(0)), + expected: "*uint16", + }, + { + input: aws.Uint32(uint32(0)), + expected: "*uint32", + }, + { + input: aws.Uint64(uint64(0)), + expected: "*uint64", + }, + { + input: aws.Int(int(0)), + expected: "*int", + }, + { + input: aws.Int8(int8(0)), + expected: "*int8", + }, + { + input: aws.Int16(int16(0)), + expected: "*int16", + }, + { + input: aws.Int32(int32(0)), + expected: "*int32", + }, + { + input: aws.Int64(int64(0)), + expected: "*int64", + }, + { + input: aws.Float32(float32(0)), + expected: "*float32", + }, + { + input: aws.Float64(float64(0)), + expected: "*float64", + }, + { + input: time.Time{}, + expected: "time.Time", + }, + { + input: &time.Time{}, + expected: "*time.Time", + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + actual := getType(c.input) + + if actual != c.expected { + t.Fatalf(`expected "%s", got "%s" for %T`, c.expected, actual, c.input) + } + }) + } +} + +func BenchmarkGetType(b *testing.B) { + x := int8(8) + for c := 0; c < b.N; c++ { + _ = getType(x) + } +} diff --git a/feature/dynamodb/entitymanager/converters/wrapper.go b/feature/dynamodb/entitymanager/converters/wrapper.go new file mode 100644 index 00000000000..d95ded0e8ea --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/wrapper.go @@ -0,0 +1,59 @@ +package converters + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// AnyAttributeConverter is a non-generic abstraction for attribute converters. +// It provides a type-erased interface so heterogeneous converters can be stored +// together (e.g., in a map[string]AnyAttributeConverter) and invoked dynamically. +// Implementations must accept a slice of tag-derived options that may influence +// conversion behavior. +type AnyAttributeConverter interface { + FromAttributeValue(types.AttributeValue, []string) (any, error) + ToAttributeValue(any, []string) (types.AttributeValue, error) +} + +// AttributeConverter defines conversion logic between a concrete Go type T and DynamoDB +// AttributeValues. Implementations encode/decode a value of T, optionally using +// tag-derived options supplied via the []string parameter. +type AttributeConverter[T any] interface { + // FromAttributeValue converts a DynamoDB AttributeValue to the Go type T. + // The second argument provides tag options for the converter. + FromAttributeValue(types.AttributeValue, []string) (T, error) + // ToAttributeValue converts a value of type T to a DynamoDB AttributeValue. + ToAttributeValue(T, []string) (types.AttributeValue, error) +} + +// Wrapper adapts an AttributeConverter[T] to the non-generic AnyAttributeConverter +// interface so converters for different concrete types can coexist in the same +// registry structure. Without this indirection, generic constraints would prevent +// a uniform collection (e.g. map[string]AttributeConverter[T]) spanning multiple T. +// +// Wrapper assumes the wrapped converter's methods are safe for concurrent use. +// It performs a runtime type assertion when converting values back to AttributeValue. +// If the provided value does not match T (and is not nil) an unsupportedType error +// is returned. +type Wrapper[T any] struct { + Impl AttributeConverter[T] +} + +// FromAttributeValue delegates to the underlying AttributeConverter[T] and returns +// the resulting value boxed as any. Tag-derived options are forwarded unchanged. +// Errors from the underlying converter are propagated as-is. +func (w *Wrapper[T]) FromAttributeValue(attr types.AttributeValue, opts []string) (any, error) { + return w.Impl.FromAttributeValue(attr, opts) +} + +// ToAttributeValue attempts to cast the supplied value to T (allowing nil) and +// delegates to the underlying converter. If the dynamic type does not match T, +// unsupportedType is returned. A nil value is passed through as the zero value +// of T; underlying converters must define how they treat it (often yielding a +// ErrNilValue for pointer types or encoding a zero-value for value types). +func (w *Wrapper[T]) ToAttributeValue(value any, opts []string) (types.AttributeValue, error) { + if v, ok := value.(T); ok || value == nil { + return w.Impl.ToAttributeValue(v, opts) + } + + return nil, unsupportedType(value, *new(T)) +} diff --git a/feature/dynamodb/entitymanager/converters/wrapper_test.go b/feature/dynamodb/entitymanager/converters/wrapper_test.go new file mode 100644 index 00000000000..9652bfa2412 --- /dev/null +++ b/feature/dynamodb/entitymanager/converters/wrapper_test.go @@ -0,0 +1,77 @@ +package converters + +import ( + "reflect" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ AttributeConverter[any] = (*mockConverter)(nil) + +type mockConverter struct { + retFrom any + retTo types.AttributeValue + fromCalled bool + retFromErr error + retToErr error + toCalled bool +} + +func (d *mockConverter) FromAttributeValue(_ types.AttributeValue, _ []string) (any, error) { + d.fromCalled = true + + return d.retFrom, d.retFromErr +} + +func (d *mockConverter) ToAttributeValue(_ any, _ []string) (types.AttributeValue, error) { + d.toCalled = true + + return d.retTo, d.retToErr +} + +func TestWrapper_FromAttributeValue(t *testing.T) { + d := &mockConverter{ + retFrom: 0, + retFromErr: ErrNilValue, + } + dw := &Wrapper[any]{Impl: d} + + actualOutput, actualError := dw.FromAttributeValue(nil, nil) + + comparisons := [][]any{ + {d.retFrom, actualOutput}, + {d.retFromErr, actualError}, + {d.fromCalled, true}, + } + + for _, cmp := range comparisons { + if !reflect.DeepEqual(cmp[0], cmp[1]) { + t.Fatalf("%#+v != %#+v", cmp[0], cmp[1]) + } + } +} + +func TestWrapper_ToAttributeValue(t *testing.T) { + d := &mockConverter{ + retTo: &types.AttributeValueMemberS{ + Value: "test", + }, + retToErr: ErrNilValue, + } + dw := &Wrapper[any]{Impl: d} + + actualOutput, actualError := dw.ToAttributeValue(nil, nil) + + comparisons := [][]any{ + {d.retTo, actualOutput}, + {d.retToErr, actualError}, + {d.toCalled, true}, + } + + for _, cmp := range comparisons { + if !reflect.DeepEqual(cmp[0], cmp[1]) { + t.Logf("%#+v != %#+v", cmp[0], cmp[1]) + } + } +} diff --git a/feature/dynamodb/entitymanager/decode.go b/feature/dynamodb/entitymanager/decode.go new file mode 100644 index 00000000000..4b98643f0c2 --- /dev/null +++ b/feature/dynamodb/entitymanager/decode.go @@ -0,0 +1,1075 @@ +package entitymanager + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager/converters" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// An Unmarshaler is an interface to provide custom unmarshaling of +// AttributeValues. Use this to provide custom logic determining +// how AttributeValues should be unmarshaled. +// +// type ExampleUnmarshaler struct { +// Value int +// } +// +// func (u *ExampleUnmarshaler) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { +// avN, ok := av.(*types.AttributeValueMemberN) +// if !ok { +// return nil +// } +// +// n, err := strconv.ParseInt(avN.Value, 10, 0) +// if err != nil { +// return err +// } +// +// u.Value = int(n) +// return nil +// } +type Unmarshaler interface { + UnmarshalDynamoDBAttributeValue(types.AttributeValue) error +} + +// Unmarshal will unmarshal AttributeValues to Go value types. +// Both generic interface{} and concrete types are valid unmarshal +// destination types. +// +// Unmarshal will allocate maps, slices, and pointers as needed to +// unmarshal the AttributeValue into the provided type value. +// +// When unmarshaling AttributeValues into structs Unmarshal matches +// the Field names of the struct to the AttributeValue Map keys. +// Initially it will look for exact Field name matching, but will +// fall back to case insensitive if not exact match is found. +// +// With the exception of omitempty, omitemptyelem, binaryset, numberset +// and stringset all struct tags used by Marshal are also used by +// Unmarshal. +// +// When decoding AttributeValues to interfaces Unmarshal will use the +// following types. +// +// []byte, AV Binary (B) +// [][]byte, AV Binary Set (BS) +// bool, AV Boolean (BOOL) +// []interface{}, AV List (L) +// map[string]interface{}, AV Map (M) +// float64, AV Number (N) +// Number, AV Number (N) with UseNumber set +// []float64, AV Number Set (NS) +// []Number, AV Number Set (NS) with UseNumber set +// string, AV String (S) +// []string, AV String Set (SS) +// +// If the Decoder option, UseNumber is set numbers will be unmarshaled +// as Number values instead of float64. Use this to maintain the original +// string formating of the number as it was represented in the AttributeValue. +// In addition provides additional opportunities to parse the number +// string based on individual use cases. +// +// When unmarshaling any error that occurs will halt the unmarshal +// and return the error. +// +// The expected value provided must be a non-nil pointer +func Unmarshal[T any](av types.AttributeValue, out interface{}) error { + return NewDecoder[T]().Decode(av, out) +} + +// UnmarshalWithOptions will unmarshal AttributeValues to Go value types. +// Both generic interface{} and concrete types are valid unmarshal +// destination types. +// +// Use the `optsFns` functional options to override the default configuration. +// +// UnmarshalWithOptions will allocate maps, slices, and pointers as needed to +// unmarshal the AttributeValue into the provided type value. +// +// When unmarshaling AttributeValues into structs Unmarshal matches +// the Field names of the struct to the AttributeValue Map keys. +// Initially it will look for exact Field name matching, but will +// fall back to case insensitive if not exact match is found. +// +// With the exception of omitempty, omitemptyelem, binaryset, numberset +// and stringset all struct tags used by Marshal are also used by +// UnmarshalWithOptions. +// +// When decoding AttributeValues to interfaces Unmarshal will use the +// following types. +// +// []byte, AV Binary (B) +// [][]byte, AV Binary Set (BS) +// bool, AV Boolean (BOOL) +// []interface{}, AV List (L) +// map[string]interface{}, AV Map (M) +// float64, AV Number (N) +// Number, AV Number (N) with UseNumber set +// []float64, AV Number Set (NS) +// []Number, AV Number Set (NS) with UseNumber set +// string, AV String (S) +// []string, AV String Set (SS) +// +// If the Decoder option, UseNumber is set numbers will be unmarshaled +// as Number values instead of float64. Use this to maintain the original +// string formating of the number as it was represented in the AttributeValue. +// In addition provides additional opportunities to parse the number +// string based on individual use cases. +// +// When unmarshaling any error that occurs will halt the unmarshal +// and return the error. +// +// The expected value provided must be a non-nil pointer +func UnmarshalWithOptions[T any](av types.AttributeValue, out T, optFns ...func(options *DecoderOptions)) error { + return NewDecoder[T](optFns...).Decode(av, out) +} + +// UnmarshalMap is an alias for Unmarshal which unmarshals from +// a map of AttributeValues. +// +// The expected value provided must be a non-nil pointer +func UnmarshalMap[T any](m map[string]types.AttributeValue, out T) error { + return NewDecoder[T]().Decode(&types.AttributeValueMemberM{Value: m}, out) +} + +// UnmarshalMapWithOptions is an alias for UnmarshalWithOptions which unmarshals from +// a map of AttributeValues. +// +// Use the `optsFns` functional options to override the default configuration. +// +// The expected value provided must be a non-nil pointer +func UnmarshalMapWithOptions[T any](m map[string]types.AttributeValue, out T, optFns ...func(options *DecoderOptions)) error { + return NewDecoder[T](optFns...).Decode(&types.AttributeValueMemberM{Value: m}, out) +} + +// UnmarshalList is an alias for Unmarshal func which unmarshals +// a slice of AttributeValues. +// +// The expected value provided must be a non-nil pointer +func UnmarshalList[T any](l []types.AttributeValue, out T) error { + return NewDecoder[T]().Decode(&types.AttributeValueMemberL{Value: l}, out) +} + +// UnmarshalListWithOptions is an alias for UnmarshalWithOptions func which unmarshals +// a slice of AttributeValues. +// +// Use the `optsFns` functional options to override the default configuration. +// +// The expected value provided must be a non-nil pointer +func UnmarshalListWithOptions[T any](l []types.AttributeValue, out T, optFns ...func(options *DecoderOptions)) error { + return NewDecoder[T](optFns...).Decode(&types.AttributeValueMemberL{Value: l}, out) +} + +// UnmarshalListOfMaps is an alias for Unmarshal func which unmarshals a +// slice of maps of attribute values. +// +// This is useful for when you need to unmarshal the Items from a Query API +// call. +// +// The expected value provided must be a non-nil pointer +func UnmarshalListOfMaps[T any](l []map[string]types.AttributeValue, out T) error { + items := make([]types.AttributeValue, len(l)) + for i, m := range l { + items[i] = &types.AttributeValueMemberM{Value: m} + } + + return UnmarshalList(items, out) +} + +// UnmarshalListOfMapsWithOptions is an alias for UnmarshalWithOptions func which unmarshals a +// slice of maps of attribute values. +// +// Use the `optsFns` functional options to override the default configuration. +// +// This is useful for when you need to unmarshal the Items from a Query API +// call. +// +// The expected value provided must be a non-nil pointer +func UnmarshalListOfMapsWithOptions(l []map[string]types.AttributeValue, out interface{}, optFns ...func(options *DecoderOptions)) error { + items := make([]types.AttributeValue, len(l)) + for i, m := range l { + items[i] = &types.AttributeValueMemberM{Value: m} + } + + return UnmarshalListWithOptions[any](items, out, optFns...) +} + +// DecodeTimeAttributes is the set of time decoding functions for different AttributeValues. +type DecodeTimeAttributes struct { + // Will decode S attribute values and SS attribute value elements into time.Time + // + // Default string parsing format is time.RFC3339 + S func(string) (time.Time, error) + // Will decode N attribute values and NS attribute value elements into time.Time + // + // Default number parsing format is seconds since January 1, 1970 UTC + N func(string) (time.Time, error) +} + +// DecoderOptions is a collection of options to configure how the decoder +// unmarshals the value. +type DecoderOptions struct { + // Support other custom struct tag keys, such as `yaml`, `json`, or `toml`. + // Note that values provided with a custom TagKey must also be supported + // by the (un)marshalers in this package. + // + // Tag key `dynamodbav` will always be read, but if custom tag key + // conflicts with `dynamodbav` the custom tag key value will be used. + TagKey string + + // Instructs the decoder to decode AttributeValue Numbers as + // Number type instead of float64 when the destination type + // is interface{}. Similar to encoding/json.Number + UseNumber bool + + // Contains the time decoding functions for different AttributeValues + // + // Default string parsing format is time.RFC3339 + // Default number parsing format is seconds since January 1, 1970 UTC + DecodeTime DecodeTimeAttributes + + // When enabled, the decoder will use implementations of + // encoding.TextUnmarshaler and encoding.BinaryUnmarshaler when present on + // unmarshaling targets. + // + // If a target implements [Unmarshaler], encoding unmarshaler + // implementations are ignored. + // + // If the attributevalue is a string, its underlying value will be used to + // call UnmarshalText on the target. If the attributevalue is a binary, its + // value will be used to call UnmarshalBinary. + UseEncodingUnmarshalers bool + + // IgnoreNilValueErrors controls whether decoding should ignore errors + // caused by nil values during schema conversion. + // If true, fields with nil values that cause conversion errors will be skipped. + // If false or nil, such cases will trigger an error. + IgnoreNilValueErrors *bool + + // ConverterRegistry provides a registry of type converters used during + // encoding and decoding operations. It will be set on both the Decoder + // and Encoder to control how values are transformed between Go types + // and schema representations. + ConverterRegistry *converters.Registry +} + +// A Decoder provides unmarshaling AttributeValues to Go value types. +type Decoder[T any] struct { + options DecoderOptions +} + +// NewDecoder creates a new Decoder with default configuration. Use +// the `opts` functional options to override the default configuration. +func NewDecoder[T any](optFns ...func(*DecoderOptions)) *Decoder[T] { + options := DecoderOptions{ + TagKey: defaultTagKey, + DecodeTime: DecodeTimeAttributes{ + S: defaultDecodeTimeS, + N: defaultDecodeTimeN, + }, + } + for _, fn := range optFns { + fn(&options) + } + + if options.DecodeTime.S == nil { + options.DecodeTime.S = defaultDecodeTimeS + } + + if options.DecodeTime.N == nil { + options.DecodeTime.N = defaultDecodeTimeN + } + + return &Decoder[T]{ + options: options, + } +} + +// Decode will unmarshal an AttributeValue into a Go value type. An error +// will be return if the decoder is unable to unmarshal the AttributeValue +// to the provide Go value type. +// +// The expected value provided must be a non-nil pointer +func (d *Decoder[T]) Decode(av types.AttributeValue, out interface{}, opts ...func(*Decoder[T])) error { + v := reflect.ValueOf(out) + if v.Kind() != reflect.Ptr || v.IsNil() || !v.IsValid() { + return &InvalidUnmarshalError{Type: reflect.TypeOf(out)} + } + + return d.decode(av, v, Tag{}) +} + +var stringInterfaceMapType = reflect.TypeOf(map[string]interface{}(nil)) +var byteSliceType = reflect.TypeOf([]byte(nil)) +var byteSliceSliceType = reflect.TypeOf([][]byte(nil)) +var timeType = reflect.TypeOf(time.Time{}) + +func (d *Decoder[T]) decode(av types.AttributeValue, v reflect.Value, fieldTag Tag) error { + var u Unmarshaler + _, isNull := av.(*types.AttributeValueMemberNULL) + if av == nil || isNull { + u, v = indirect[Unmarshaler](v, indirectOptions{decodeNull: true}) + if u != nil { + return u.UnmarshalDynamoDBAttributeValue(av) + } + return d.decodeNull(v) + } + + if d.options.ConverterRegistry != nil && fieldTag.Converter { + el := valueElem(v) + cvtName := el.Type().String() + + opts, ok := fieldTag.Option("converter") + if ok { + cvtName = opts[0] + } + + if cvt := d.options.ConverterRegistry.Converter(cvtName); cvt != nil { + vr, err := cvt.FromAttributeValue(av, opts) + + if errors.Is(converters.ErrNilValue, err) && !unwrap(d.options.IgnoreNilValueErrors) { + err = nil + } + + if err != nil { + return err + } + + rv := reflect.ValueOf(vr) + el.Set(rv) + + return nil + } + } + + v0 := v + u, v = indirect[Unmarshaler](v, indirectOptions{}) + if u != nil { + return u.UnmarshalDynamoDBAttributeValue(av) + } + if d.options.UseEncodingUnmarshalers { + if s, ok := av.(*types.AttributeValueMemberS); ok { + if u, _ := indirect[encoding.TextUnmarshaler](v0, indirectOptions{}); u != nil { + return u.UnmarshalText([]byte(s.Value)) + } + } + if b, ok := av.(*types.AttributeValueMemberB); ok { + if u, _ := indirect[encoding.BinaryUnmarshaler](v0, indirectOptions{}); u != nil { + return u.UnmarshalBinary(b.Value) + } + } + } + + switch tv := av.(type) { + case *types.AttributeValueMemberB: + return d.decodeBinary(tv.Value, v) + + case *types.AttributeValueMemberBOOL: + return d.decodeBool(tv.Value, v) + + case *types.AttributeValueMemberBS: + return d.decodeBinarySet(tv.Value, v) + + case *types.AttributeValueMemberL: + return d.decodeList(tv.Value, v) + + case *types.AttributeValueMemberM: + return d.decodeMap(tv.Value, v) + + case *types.AttributeValueMemberN: + return d.decodeNumber(tv.Value, v, fieldTag) + + case *types.AttributeValueMemberNS: + return d.decodeNumberSet(tv.Value, v) + + case *types.AttributeValueMemberS: + return d.decodeString(tv.Value, v, fieldTag) + + case *types.AttributeValueMemberSS: + return d.decodeStringSet(tv.Value, v) + + default: + return fmt.Errorf("unsupported AttributeValue type, %V", av) + } +} + +func (d *Decoder[T]) decodeBinary(b []byte, v reflect.Value) error { + if v.Kind() == reflect.Interface { + buf := make([]byte, len(b)) + copy(buf, b) + v.Set(reflect.ValueOf(buf)) + return nil + } + + if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { + return &UnmarshalTypeError{Value: "binary", Type: v.Type()} + } + + if v.Type() == byteSliceType { + // Optimization for []byte types + if v.IsNil() || v.Cap() < len(b) { + v.Set(reflect.MakeSlice(byteSliceType, len(b), len(b))) + } else if v.Len() != len(b) { + v.SetLen(len(b)) + } + copy(v.Interface().([]byte), b) + return nil + } + + switch v.Type().Elem().Kind() { + case reflect.Uint8: + // Fallback to reflection copy for type aliased of []byte type + if v.Kind() != reflect.Array && (v.IsNil() || v.Cap() < len(b)) { + v.Set(reflect.MakeSlice(v.Type(), len(b), len(b))) + } else if v.Len() != len(b) { + v.SetLen(len(b)) + } + for i := 0; i < len(b); i++ { + v.Index(i).SetUint(uint64(b[i])) + } + default: + if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 { + reflect.Copy(v, reflect.ValueOf(b)) + break + } + return &UnmarshalTypeError{Value: "binary", Type: v.Type()} + } + + return nil +} + +func (d *Decoder[T]) decodeBool(b bool, v reflect.Value) error { + switch v.Kind() { + case reflect.Bool, reflect.Interface: + v.Set(reflect.ValueOf(b).Convert(v.Type())) + + default: + return &UnmarshalTypeError{Value: "bool", Type: v.Type()} + } + + return nil +} + +func (d *Decoder[T]) decodeBinarySet(bs [][]byte, v reflect.Value) error { + var isArray bool + + switch v.Kind() { + case reflect.Slice: + // Make room for the slice elements if needed + if v.IsNil() || v.Cap() < len(bs) { + // What about if ignoring nil/empty values? + v.Set(reflect.MakeSlice(v.Type(), 0, len(bs))) + } + case reflect.Array: + // Limited to capacity of existing array. + isArray = true + case reflect.Interface: + set := make([][]byte, len(bs)) + for i, b := range bs { + if err := d.decodeBinary(b, reflect.ValueOf(&set[i]).Elem()); err != nil { + return err + } + } + v.Set(reflect.ValueOf(set)) + return nil + default: + return &UnmarshalTypeError{Value: "binary set", Type: v.Type()} + } + + for i := 0; i < v.Cap() && i < len(bs); i++ { + if !isArray { + v.SetLen(i + 1) + } + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) + if u != nil { + err := u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberB{Value: bs[i]}) + if err != nil { + return err + } + continue + } + if err := d.decodeBinary(bs[i], elem); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder[T]) decodeNumber(n string, v reflect.Value, fieldTag Tag) error { + switch v.Kind() { + case reflect.Interface: + i, err := d.decodeNumberToInterface(n) + if err != nil { + return err + } + v.Set(reflect.ValueOf(i)) + return nil + case reflect.String: + if isNumberValueType(v) { + v.SetString(n) + return nil + } + v.SetString(n) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return err + } + if v.OverflowInt(i) { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, %s", n), + Type: v.Type(), + } + } + v.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, err := strconv.ParseUint(n, 10, 64) + if err != nil { + return err + } + if v.OverflowUint(i) { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, %s", n), + Type: v.Type(), + } + } + v.SetUint(i) + case reflect.Float32, reflect.Float64: + i, err := strconv.ParseFloat(n, 64) + if err != nil { + return err + } + if v.OverflowFloat(i) { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, %s", n), + Type: v.Type(), + } + } + v.SetFloat(i) + default: + if v.Type().ConvertibleTo(timeType) && fieldTag.AsUnixTime { + t, err := decodeUnixTime(n) + if err != nil { + return err + } + v.Set(reflect.ValueOf(t).Convert(v.Type())) + return nil + } + if v.Type().ConvertibleTo(timeType) { + t, err := d.options.DecodeTime.N(n) + if err != nil { + return err + } + v.Set(reflect.ValueOf(t).Convert(v.Type())) + return nil + } + return &UnmarshalTypeError{Value: "number", Type: v.Type()} + } + + return nil +} + +func (d *Decoder[T]) decodeNumberToInterface(n string) (interface{}, error) { + if d.options.UseNumber { + return Number(n), nil + } + + // Default to float64 for all numbers + return strconv.ParseFloat(n, 64) +} + +func (d *Decoder[T]) decodeNumberSet(ns []string, v reflect.Value) error { + var isArray bool + + switch v.Kind() { + case reflect.Slice: + // Make room for the slice elements if needed + if v.IsNil() || v.Cap() < len(ns) { + // What about if ignoring nil/empty values? + v.Set(reflect.MakeSlice(v.Type(), 0, len(ns))) + } + case reflect.Array: + // Limited to capacity of existing array. + isArray = true + case reflect.Interface: + if d.options.UseNumber { + set := make([]Number, len(ns)) + for i, n := range ns { + if err := d.decodeNumber(n, reflect.ValueOf(&set[i]).Elem(), Tag{}); err != nil { + return err + } + } + v.Set(reflect.ValueOf(set)) + } else { + set := make([]float64, len(ns)) + for i, n := range ns { + if err := d.decodeNumber(n, reflect.ValueOf(&set[i]).Elem(), Tag{}); err != nil { + return err + } + } + v.Set(reflect.ValueOf(set)) + } + return nil + default: + return &UnmarshalTypeError{Value: "number set", Type: v.Type()} + } + + for i := 0; i < v.Cap() && i < len(ns); i++ { + if !isArray { + v.SetLen(i + 1) + } + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) + if u != nil { + err := u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberN{Value: ns[i]}) + if err != nil { + return err + } + continue + } + if err := d.decodeNumber(ns[i], elem, Tag{}); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder[T]) decodeList(avList []types.AttributeValue, v reflect.Value) error { + var isArray bool + + switch v.Kind() { + case reflect.Slice: + // Make room for the slice elements if needed + if v.IsNil() || v.Cap() < len(avList) { + // What about if ignoring nil/empty values? + v.Set(reflect.MakeSlice(v.Type(), 0, len(avList))) + } + case reflect.Array: + // Limited to capacity of existing array. + isArray = true + case reflect.Interface: + s := make([]interface{}, len(avList)) + for i, av := range avList { + if err := d.decode(av, reflect.ValueOf(&s[i]).Elem(), Tag{}); err != nil { + return err + } + } + v.Set(reflect.ValueOf(s)) + return nil + default: + return &UnmarshalTypeError{Value: "list", Type: v.Type()} + } + + // If v is not a slice, array + for i := 0; i < v.Cap() && i < len(avList); i++ { + if !isArray { + v.SetLen(i + 1) + } + if err := d.decode(avList[i], v.Index(i), Tag{}); err != nil { + return err + } + } + + return nil +} + +func (d *Decoder[T]) decodeMap(avMap map[string]types.AttributeValue, v reflect.Value) (err error) { + var decodeMapKey func(v string, key reflect.Value, fieldTag Tag) error + + switch v.Kind() { + case reflect.Map: + decodeMapKey, err = d.getMapKeyDecoder(v.Type().Key()) + if err != nil { + return err + } + + if v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + case reflect.Struct: + case reflect.Interface: + v.Set(reflect.MakeMap(stringInterfaceMapType)) + decodeMapKey = d.decodeString + v = v.Elem() + default: + return &UnmarshalTypeError{Value: "map", Type: v.Type()} + } + + if v.Kind() == reflect.Map { + keyType := v.Type().Key() + valueType := v.Type().Elem() + for k, av := range avMap { + key := reflect.New(keyType).Elem() + // handle pointer keys + _, indirectKey := indirect[Unmarshaler](key, indirectOptions{skipUnmarshaler: true}) + if err := decodeMapKey(k, indirectKey, Tag{}); err != nil { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("map key %q", k), + Type: keyType, + Err: err, + } + } + + elem := reflect.New(valueType).Elem() + if err := d.decode(av, elem, Tag{}); err != nil { + return err + } + + v.SetMapIndex(key, elem) + } + } else if v.Kind() == reflect.Struct { + fields := unionStructFields(v.Type(), structFieldOptions{ + TagKey: d.options.TagKey, + }) + for k, av := range avMap { + if f, ok := fields.FieldByName(k); ok { + fv := decoderFieldByIndex(v, f.Index) + if err := d.decode(av, fv, f.Tag); err != nil { + return err + } + } + } + } + + return nil +} + +var numberType = reflect.TypeOf(Number("")) +var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func (d *Decoder[T]) getMapKeyDecoder(keyType reflect.Type) (func(string, reflect.Value, Tag) error, error) { + // Test the key type to determine if it implements the TextUnmarshaler interface. + if reflect.PtrTo(keyType).Implements(textUnmarshalerType) || keyType.Implements(textUnmarshalerType) { + return func(v string, k reflect.Value, _ Tag) error { + if !k.CanAddr() { + return fmt.Errorf("cannot take address of map key, %v", k.Type()) + } + return k.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(v)) + }, nil + } + + var decodeMapKey func(v string, key reflect.Value, fieldTag Tag) error + + switch keyType.Kind() { + case reflect.Bool: + decodeMapKey = func(v string, key reflect.Value, fieldTag Tag) error { + b, err := strconv.ParseBool(v) + if err != nil { + return err + } + return d.decodeBool(b, key) + } + case reflect.String: + // Number type handled as a string + decodeMapKey = d.decodeString + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + decodeMapKey = d.decodeNumber + + default: + return nil, &UnmarshalTypeError{ + Value: "map key must be string, number, bool, or TextUnmarshaler", + Type: keyType, + } + } + + return decodeMapKey, nil +} + +func (d *Decoder[T]) decodeNull(v reflect.Value) error { + if v.IsValid() && v.CanSet() { + v.Set(reflect.Zero(v.Type())) + } + + return nil +} + +func (d *Decoder[T]) decodeString(s string, v reflect.Value, fieldTag Tag) error { + if fieldTag.AsString { + return d.decodeNumber(s, v, fieldTag) + } + + // To maintain backwards compatibility with ConvertFrom family of methods which + // converted strings to time.Time structs + if v.Type().ConvertibleTo(timeType) { + t, err := d.options.DecodeTime.S(s) + if err != nil { + return err + } + v.Set(reflect.ValueOf(t).Convert(v.Type())) + return nil + } + + switch v.Kind() { + case reflect.String: + v.SetString(s) + case reflect.Interface: + // Ensure type aliasing is handled properly + v.Set(reflect.ValueOf(s).Convert(v.Type())) + default: + return &UnmarshalTypeError{Value: "string", Type: v.Type()} + } + + return nil +} + +func (d *Decoder[T]) decodeStringSet(ss []string, v reflect.Value) error { + var isArray bool + + switch v.Kind() { + case reflect.Slice: + // Make room for the slice elements if needed + if v.IsNil() || v.Cap() < len(ss) { + v.Set(reflect.MakeSlice(v.Type(), 0, len(ss))) + } + case reflect.Array: + // Limited to capacity of existing array. + isArray = true + case reflect.Interface: + set := make([]string, len(ss)) + for i, s := range ss { + if err := d.decodeString(s, reflect.ValueOf(&set[i]).Elem(), Tag{}); err != nil { + return err + } + } + v.Set(reflect.ValueOf(set)) + return nil + default: + return &UnmarshalTypeError{Value: "string set", Type: v.Type()} + } + + for i := 0; i < v.Cap() && i < len(ss); i++ { + if !isArray { + v.SetLen(i + 1) + } + u, elem := indirect[Unmarshaler](v.Index(i), indirectOptions{}) + if u != nil { + err := u.UnmarshalDynamoDBAttributeValue(&types.AttributeValueMemberS{Value: ss[i]}) + if err != nil { + return err + } + continue + } + if err := d.decodeString(ss[i], elem, Tag{}); err != nil { + return err + } + } + + return nil +} + +func decodeUnixTime(n string) (time.Time, error) { + v, err := strconv.ParseInt(n, 10, 64) + if err != nil { + return time.Time{}, &UnmarshalError{ + Err: err, Value: n, Type: timeType, + } + } + + return time.Unix(v, 0), nil +} + +// decoderFieldByIndex finds the Field with the provided nested index, allocating +// embedded parent structs if needed +func decoderFieldByIndex(v reflect.Value, index []int) reflect.Value { + for i, x := range index { + if i > 0 && v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + } + v = v.Field(x) + } + return v +} + +type indirectOptions struct { + decodeNull bool + skipUnmarshaler bool +} + +// indirect will walk a value's interface or pointer value types. Returning +// the final value or the value a unmarshaler is defined on. +// +// Based on the enoding/json type reflect value type indirection in Go Stdlib +// https://golang.org/src/encoding/json/decode.go indirect func. +func indirect[U any](v reflect.Value, opts indirectOptions) (U, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Ptr && !e.IsNil() && (!opts.decodeNull || e.Elem().Kind() == reflect.Ptr) { + haveAddr = false + v = e + continue + } + if e.Kind() != reflect.Ptr && e.IsValid() { + var u U + return u, e + } + } + if v.Kind() != reflect.Ptr { + break + } + if opts.decodeNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v interface{} + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if !opts.skipUnmarshaler && v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(U); ok { + return u, reflect.Value{} + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + + var u U + return u, v +} + +// A Number represents a Attributevalue number literal. +type Number string + +// Float64 attempts to cast the number to a float64, returning +// the result of the case or error if the case failed. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 attempts to cast the number to a int64, returning +// the result of the case or error if the case failed. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// Uint64 attempts to cast the number to a uint64, returning +// the result of the case or error if the case failed. +func (n Number) Uint64() (uint64, error) { + return strconv.ParseUint(string(n), 10, 64) +} + +// String returns the raw number represented as a string +func (n Number) String() string { + return string(n) +} + +// An UnmarshalTypeError is an error type representing a error +// unmarshaling the AttributeValue's element to a Go value type. +// Includes details about the AttributeValue type and Go value type. +type UnmarshalTypeError struct { + Value string + Type reflect.Type + Err error +} + +// Unwrap returns the underlying error if any. +func (e *UnmarshalTypeError) Unwrap() error { return e.Err } + +// Error returns the string representation of the error. +// satisfying the error interface +func (e *UnmarshalTypeError) Error() string { + return fmt.Sprintf("unmarshal failed, cannot unmarshal %s into Go value type %s", + e.Value, e.Type.String()) +} + +// An InvalidUnmarshalError is an error type representing an invalid type +// encountered while unmarshaling a AttributeValue to a Go value type. +type InvalidUnmarshalError struct { + Type reflect.Type +} + +// Error returns the string representation of the error. +// satisfying the error interface +func (e *InvalidUnmarshalError) Error() string { + var msg string + if e.Type == nil { + msg = "cannot unmarshal to nil value" + } else if e.Type.Kind() != reflect.Ptr { + msg = fmt.Sprintf("cannot unmarshal to non-pointer value, got %s", e.Type.String()) + } else { + msg = fmt.Sprintf("cannot unmarshal to nil value, %s", e.Type.String()) + } + + return fmt.Sprintf("unmarshal failed, %s", msg) +} + +// An UnmarshalError wraps an error that occurred while unmarshaling a +// AttributeValue element into a Go type. This is different from +// UnmarshalTypeError in that it wraps the underlying error that occurred. +type UnmarshalError struct { + Err error + Value string + Type reflect.Type +} + +func (e *UnmarshalError) Unwrap() error { + return e.Err +} + +// Error returns the string representation of the error satisfying the error +// interface. +func (e *UnmarshalError) Error() string { + return fmt.Sprintf("unmarshal failed, cannot unmarshal %q into %s, %v", + e.Value, e.Type.String(), e.Err) +} + +func defaultDecodeTimeS(v string) (time.Time, error) { + t, err := time.Parse(time.RFC3339, v) + if err != nil { + return time.Time{}, &UnmarshalError{Err: err, Value: v, Type: timeType} + } + return t, nil +} + +func defaultDecodeTimeN(v string) (time.Time, error) { + return decodeUnixTime(v) +} diff --git a/feature/dynamodb/entitymanager/decode_test.go b/feature/dynamodb/entitymanager/decode_test.go new file mode 100644 index 00000000000..adb5d436930 --- /dev/null +++ b/feature/dynamodb/entitymanager/decode_test.go @@ -0,0 +1,1547 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager/converters" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestUnmarshalShared(t *testing.T) { + for name, c := range sharedTestCases { + t.Run(name, func(t *testing.T) { + err := Unmarshal[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshal(t *testing.T) { + cases := []struct { + in types.AttributeValue + actual, expected interface{} + err error + }{ + //------------ + // Sets + //------------ + { + in: &types.AttributeValueMemberBS{Value: [][]byte{ + {48, 49}, {50, 51}, + }}, + actual: &[][]byte{}, + expected: [][]byte{{48, 49}, {50, 51}}, + }, + { + in: &types.AttributeValueMemberNS{Value: []string{ + "123", "321", + }}, + actual: &[]int{}, + expected: []int{123, 321}, + }, + { + in: &types.AttributeValueMemberNS{Value: []string{ + "123", "321", + }}, + actual: &[]interface{}{}, + expected: []interface{}{123., 321.}, + }, + { + in: &types.AttributeValueMemberSS{Value: []string{ + "abc", "123", + }}, + actual: &[]string{}, + expected: &[]string{"abc", "123"}, + }, + { + in: &types.AttributeValueMemberSS{Value: []string{ + "abc", "123", + }}, + actual: &[]*string{}, + expected: &[]*string{pointer("abc"), pointer("123")}, + }, + //------------ + // Interfaces + //------------ + { + in: &types.AttributeValueMemberB{Value: []byte{48, 49}}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: []byte{48, 49}, + }, + { + in: &types.AttributeValueMemberBS{Value: [][]byte{ + {48, 49}, {50, 51}, + }}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: [][]byte{{48, 49}, {50, 51}}, + }, + { + in: &types.AttributeValueMemberBOOL{Value: true}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: bool(true), + }, + { + in: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "abc"}, + &types.AttributeValueMemberS{Value: "123"}, + }}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: []interface{}{"abc", "123"}, + }, + { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "123": &types.AttributeValueMemberS{Value: "abc"}, + "abc": &types.AttributeValueMemberS{Value: "123"}, + }}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: map[string]interface{}{"123": "abc", "abc": "123"}, + }, + { + in: &types.AttributeValueMemberN{Value: "123"}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: float64(123), + }, + { + in: &types.AttributeValueMemberNS{Value: []string{ + "123", "321", + }}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: []float64{123., 321.}, + }, + { + in: &types.AttributeValueMemberS{Value: "123"}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: "123", + }, + { + in: &types.AttributeValueMemberNULL{Value: true}, + actual: func() interface{} { + var v string + return &v + }(), + expected: "", + }, + { + in: &types.AttributeValueMemberNULL{Value: true}, + actual: func() interface{} { + v := new(string) + return &v + }(), + expected: nil, + }, + { + in: &types.AttributeValueMemberS{Value: ""}, + actual: func() interface{} { + v := new(string) + return &v + }(), + expected: pointer(""), + }, + { + in: &types.AttributeValueMemberSS{Value: []string{ + "123", "321", + }}, + actual: func() interface{} { + var v interface{} + return &v + }(), + expected: []string{"123", "321"}, + }, + { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "123"}, + "Cba": &types.AttributeValueMemberS{Value: "321"}, + }}, + actual: &struct{ Abc, Cba string }{}, + expected: struct{ Abc, Cba string }{Abc: "123", Cba: "321"}, + }, + { + in: &types.AttributeValueMemberN{Value: "512"}, + actual: new(uint8), + err: &UnmarshalTypeError{ + Value: fmt.Sprintf("number overflow, 512"), + Type: reflect.TypeOf(uint8(0)), + }, + }, + // ------- + // Empty Values + // ------- + { + in: &types.AttributeValueMemberB{Value: []byte{}}, + actual: &[]byte{}, + expected: []byte{}, + }, + { + in: &types.AttributeValueMemberBS{Value: [][]byte{}}, + actual: &[][]byte{}, + expected: [][]byte{}, + }, + { + in: &types.AttributeValueMemberL{Value: []types.AttributeValue{}}, + actual: &[]interface{}{}, + expected: []interface{}{}, + }, + { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}}, + actual: &map[string]interface{}{}, + expected: map[string]interface{}{}, + }, + { + in: &types.AttributeValueMemberN{Value: ""}, + actual: new(int), + err: fmt.Errorf("invalid syntax"), + }, + { + in: &types.AttributeValueMemberNS{Value: []string{}}, + actual: &[]string{}, + expected: []string{}, + }, + { + in: &types.AttributeValueMemberS{Value: ""}, + actual: new(string), + expected: "", + }, + { + in: &types.AttributeValueMemberSS{Value: []string{}}, + actual: &[]string{}, + expected: []string{}, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d/%d", i, len(cases)), func(t *testing.T) { + err := Unmarshal[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestInterfaceInput(t *testing.T) { + var v interface{} + expected := []interface{}{"abc", "123"} + err := Unmarshal[any](&types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "abc"}, + &types.AttributeValueMemberS{Value: "123"}, + }}, &v) + assertConvertTest(t, v, expected, err, nil) +} + +func TestUnmarshalError(t *testing.T) { + cases := map[string]struct { + in types.AttributeValue + actual, expected interface{} + err error + }{ + "invalid unmarshal": { + in: nil, + actual: int(0), + expected: nil, + err: &InvalidUnmarshalError{Type: reflect.TypeOf(int(0))}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + err := Unmarshal[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshalListShared(t *testing.T) { + for name, c := range sharedListTestCases { + t.Run(name, func(t *testing.T) { + err := UnmarshalList(c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshalListError(t *testing.T) { + cases := map[string]struct { + in []types.AttributeValue + actual, expected interface{} + err error + }{ + "invalid unmarshal": { + in: []types.AttributeValue{}, + actual: []interface{}{}, + expected: nil, + err: &InvalidUnmarshalError{Type: reflect.TypeOf([]interface{}{})}, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + err := UnmarshalList(c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshalMapShared(t *testing.T) { + for name, c := range sharedMapTestCases { + t.Run(name, func(t *testing.T) { + err := UnmarshalMap[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshalMapError(t *testing.T) { + cases := []struct { + in map[string]types.AttributeValue + actual, expected interface{} + err error + }{ + { + in: map[string]types.AttributeValue{}, + actual: map[string]interface{}{}, + expected: nil, + err: &InvalidUnmarshalError{Type: reflect.TypeOf(map[string]interface{}{})}, + }, + { + in: map[string]types.AttributeValue{ + "BOOL": &types.AttributeValueMemberBOOL{Value: true}, + }, + actual: &map[int]interface{}{}, + expected: nil, + err: &UnmarshalTypeError{ + Value: `map key "BOOL"`, + Type: reflect.TypeOf(int(0)), + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + err := UnmarshalMap[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +func TestUnmarshalListOfMaps(t *testing.T) { + type testItem struct { + Value string + Value2 int + } + + cases := map[string]struct { + in []map[string]types.AttributeValue + actual, expected interface{} + err error + }{ + "simple map conversion": { + in: []map[string]types.AttributeValue{ + { + "Value": &types.AttributeValueMemberBOOL{Value: true}, + }, + }, + actual: &[]map[string]interface{}{}, + expected: []map[string]interface{}{ + { + "Value": true, + }, + }, + }, + "attribute to struct": { + in: []map[string]types.AttributeValue{ + { + "Value": &types.AttributeValueMemberS{Value: "abc"}, + "Value2": &types.AttributeValueMemberN{Value: "123"}, + }, + }, + actual: &[]testItem{}, + expected: []testItem{ + { + Value: "abc", + Value2: 123, + }, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + err := UnmarshalListOfMaps[any](c.in, c.actual) + assertConvertTest(t, c.actual, c.expected, err, c.err) + }) + } +} + +type unmarshalUnmarshaler struct { + Value string + Value2 int + Value3 bool + Value4 time.Time +} + +func (u *unmarshalUnmarshaler) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + m, ok := av.(*types.AttributeValueMemberM) + if !ok || m == nil { + return fmt.Errorf("expected AttributeValue to be map") + } + + if v, ok := m.Value["abc"]; !ok { + return fmt.Errorf("expected `abc` map key") + } else if vv, kk := v.(*types.AttributeValueMemberS); !kk || vv == nil { + return fmt.Errorf("expected `abc` map value string") + } else { + u.Value = vv.Value + } + + if v, ok := m.Value["def"]; !ok { + return fmt.Errorf("expected `def` map key") + } else if vv, kk := v.(*types.AttributeValueMemberN); !kk || vv == nil { + return fmt.Errorf("expected `def` map value number") + } else { + n, err := strconv.ParseInt(vv.Value, 10, 64) + if err != nil { + return err + } + u.Value2 = int(n) + } + + if v, ok := m.Value["ghi"]; !ok { + return fmt.Errorf("expected `ghi` map key") + } else if vv, kk := v.(*types.AttributeValueMemberBOOL); !kk || vv == nil { + return fmt.Errorf("expected `ghi` map value number") + } else { + u.Value3 = vv.Value + } + + if v, ok := m.Value["jkl"]; !ok { + return fmt.Errorf("expected `jkl` map key") + } else if vv, kk := v.(*types.AttributeValueMemberS); !kk || vv == nil { + return fmt.Errorf("expected `jkl` map value string") + } else { + t, err := time.Parse(time.RFC3339, vv.Value) + if err != nil { + return err + } + u.Value4 = t + } + + return nil +} + +func TestUnmarshalUnmashaler(t *testing.T) { + u := &unmarshalUnmarshaler{} + av := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "value"}, + "def": &types.AttributeValueMemberN{Value: "123"}, + "ghi": &types.AttributeValueMemberBOOL{Value: true}, + "jkl": &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + }, + } + + err := Unmarshal[any](av, u) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if e, a := "value", u.Value; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := 123, u.Value2; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := true, u.Value3; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := testDate, u.Value4; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestDecodeUseNumber(t *testing.T) { + u := map[string]interface{}{} + av := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "value"}, + "def": &types.AttributeValueMemberN{Value: "123"}, + "ghi": &types.AttributeValueMemberBOOL{Value: true}, + }, + } + + decoder := NewDecoder[any](func(o *DecoderOptions) { + o.UseNumber = true + }) + err := decoder.Decode(av, &u) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + if e, a := "value", u["abc"]; e != a { + t.Errorf("expect %v, got %v", e, a) + } + n := u["def"].(Number) + if e, a := "123", n.String(); e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := true, u["ghi"]; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestDecodeUseNumberNumberSet(t *testing.T) { + u := map[string]interface{}{} + av := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "ns": &types.AttributeValueMemberNS{ + Value: []string{ + "123", "321", + }, + }, + }, + } + + decoder := NewDecoder[any](func(o *DecoderOptions) { + o.UseNumber = true + }) + err := decoder.Decode(av, &u) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + + ns := u["ns"].([]Number) + + if e, a := "123", ns[0].String(); e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "321", ns[1].String(); e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestDecodeEmbeddedPointerStruct(t *testing.T) { + type B struct { + Bint int + } + type C struct { + Cint int + } + type A struct { + Aint int + *B + *C + } + av := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Aint": &types.AttributeValueMemberN{Value: "321"}, + "Bint": &types.AttributeValueMemberN{Value: "123"}, + }, + } + decoder := NewDecoder[any]() + a := A{} + err := decoder.Decode(av, &a) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if e, a := 321, a.Aint; e != a { + t.Errorf("expect %v, got %v", e, a) + } + // Embedded pointer struct can be created automatically. + if e, a := 123, a.Bint; e != a { + t.Errorf("expect %v, got %v", e, a) + } + // But not for absent CachedFields. + if a.C != nil { + t.Errorf("expect nil, got %v", a.C) + } +} + +func TestDecodeBooleanOverlay(t *testing.T) { + type BooleanOverlay bool + + av := &types.AttributeValueMemberBOOL{Value: true} + + decoder := NewDecoder[any]() + + var v BooleanOverlay + + err := decoder.Decode(av, &v) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if e, a := BooleanOverlay(true), v; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestDecodeUnixTime(t *testing.T) { + type A struct { + Normal time.Time + Tagged time.Time `dynamodbav:",unixtime"` + Typed UnixTime + } + + expect := A{ + Normal: time.Unix(123, 0).UTC(), + Tagged: time.Unix(456, 0), + Typed: UnixTime(time.Unix(789, 0)), + } + + input := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Normal": &types.AttributeValueMemberS{Value: "1970-01-01T00:02:03Z"}, + "Tagged": &types.AttributeValueMemberN{Value: "456"}, + "Typed": &types.AttributeValueMemberN{Value: "789"}, + }, + } + actual := A{} + + err := Unmarshal[any](input, &actual) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if e, a := expect, actual; e != a { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestDecodeAliasedUnixTime(t *testing.T) { + type A struct { + Normal AliasedTime + Tagged AliasedTime `dynamodbav:",unixtime"` + } + + expect := A{ + Normal: AliasedTime(time.Unix(123, 0).UTC()), + Tagged: AliasedTime(time.Unix(456, 0)), + } + + input := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Normal": &types.AttributeValueMemberS{Value: "1970-01-01T00:02:03Z"}, + "Tagged": &types.AttributeValueMemberN{Value: "456"}, + }, + } + actual := A{} + + err := Unmarshal[any](input, &actual) + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if expect != actual { + t.Errorf("expect %v, got %v", expect, actual) + } +} + +// see github issue #1594 +func TestDecodeArrayType(t *testing.T) { + cases := []struct { + to, from interface{} + }{ + { + &[2]int{1, 2}, + &[2]int{}, + }, + { + &[2]int64{1, 2}, + &[2]int64{}, + }, + { + &[2]byte{1, 2}, + &[2]byte{}, + }, + { + &[2]bool{true, false}, + &[2]bool{}, + }, + { + &[2]string{"1", "2"}, + &[2]string{}, + }, + { + &[2][]string{{"1", "2"}}, + &[2][]string{}, + }, + } + + for _, c := range cases { + marshaled, err := Marshal(c.to) + if err != nil { + t.Errorf("expected no error, but received %v", err) + } + + if err = Unmarshal[any](marshaled, c.from); err != nil { + t.Errorf("expected no error, but received %v", err) + } + + if diff := cmpDiff(c.to, c.from); len(diff) != 0 { + t.Errorf("expected match\n:%s", diff) + } + } +} + +func TestDecoderFieldByIndex(t *testing.T) { + type ( + Middle struct{ Inner int } + Outer struct{ *Middle } + ) + var outer Outer + + outerType := reflect.TypeOf(outer) + outerValue := reflect.ValueOf(&outer) + outerFields := unionStructFields(outerType, structFieldOptions{}) + innerField, _ := outerFields.FieldByName("Inner") + + f := decoderFieldByIndex(outerValue.Elem(), innerField.Index) + if outer.Middle == nil { + t.Errorf("expected outer.Middle to be non-nil") + } + if f.Kind() != reflect.Int || f.Int() != int64(outer.Inner) { + t.Error("expected f to be an int with value equal to outer.Inner") + } +} +func TestDecodeAliasType(t *testing.T) { + type Str string + type Int int + type Uint uint + type TT struct { + A Str + B Int + C Uint + S Str + } + + expect := TT{ + A: "12345", + B: 12345, + C: 12345, + S: "string", + } + m := map[string]types.AttributeValue{ + "A": &types.AttributeValueMemberN{ + Value: "12345", + }, + "B": &types.AttributeValueMemberN{ + Value: "12345", + }, + "C": &types.AttributeValueMemberN{ + Value: "12345", + }, + "S": &types.AttributeValueMemberS{ + Value: "string", + }, + } + + var actual TT + err := UnmarshalMap(m, &actual) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !reflect.DeepEqual(expect, actual) { + t.Errorf("expect:\n%v\nactual:\n%v", expect, actual) + } +} + +type testUnmarshalMapKeyComplex struct { + Foo string +} + +func (t *testUnmarshalMapKeyComplex) UnmarshalText(b []byte) error { + t.Foo = string(b) + return nil +} +func (t *testUnmarshalMapKeyComplex) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + avM, ok := av.(*types.AttributeValueMemberM) + if !ok { + return fmt.Errorf("unexpected AttributeValue type %T, %v", av, av) + } + avFoo, ok := avM.Value["foo"] + if !ok { + return nil + } + + avS, ok := avFoo.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("unexpected Foo AttributeValue type, %T, %v", avM, avM) + } + + t.Foo = avS.Value + + return nil +} + +func TestUnmarshalTime_S_SS(t *testing.T) { + type A struct { + TimeField time.Time + TimeFields []time.Time + TimeFieldsL []time.Time + } + cases := map[string]struct { + input string + expect time.Time + decodeTimeS func(string) (time.Time, error) + }{ + "String RFC3339Nano (Default)": { + input: "1970-01-01T00:02:03.01Z", + expect: time.Unix(123, 10000000).UTC(), + }, + "String UnixDate": { + input: "Thu Jan 1 00:02:03 UTC 1970", + expect: time.Unix(123, 0).UTC(), + decodeTimeS: func(v string) (time.Time, error) { + t, err := time.Parse(time.UnixDate, v) + if err != nil { + return time.Time{}, &UnmarshalError{Err: err, Value: v, Type: timeType} + } + return t, nil + }, + }, + "String RFC3339 millis keeping zeroes": { + input: "1970-01-01T00:02:03.010Z", + expect: time.Unix(123, 10000000).UTC(), + decodeTimeS: func(v string) (time.Time, error) { + t, err := time.Parse("2006-01-02T15:04:05.000Z07:00", v) + if err != nil { + return time.Time{}, &UnmarshalError{Err: err, Value: v, Type: timeType} + } + return t, nil + }, + }, + "String RFC822": { + input: "01 Jan 70 00:02 UTC", + expect: time.Unix(120, 0).UTC(), + decodeTimeS: func(v string) (time.Time, error) { + t, err := time.Parse(time.RFC822, v) + if err != nil { + return time.Time{}, &UnmarshalError{Err: err, Value: v, Type: timeType} + } + return t, nil + }, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + inputMap := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeField": &types.AttributeValueMemberS{Value: c.input}, + "TimeFields": &types.AttributeValueMemberSS{Value: []string{c.input}}, + "TimeFieldsL": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: c.input}, + }}, + }, + } + expectedValue := A{ + TimeField: c.expect, + TimeFields: []time.Time{c.expect}, + TimeFieldsL: []time.Time{c.expect}, + } + + var actualValue A + if err := UnmarshalWithOptions(inputMap, &actualValue, func(options *DecoderOptions) { + if c.decodeTimeS != nil { + options.DecodeTime.S = c.decodeTimeS + } + }); err != nil { + t.Errorf("expect no error, got %v", err) + } + if diff := cmpDiff(expectedValue, actualValue); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalTime_N_NS(t *testing.T) { + type A struct { + TimeField time.Time + TimeFields []time.Time + TimeFieldsL []time.Time + } + cases := map[string]struct { + input string + expect time.Time + decodeTimeN func(string) (time.Time, error) + }{ + "Number Unix seconds (Default)": { + input: "123", + expect: time.Unix(123, 0), + }, + "Number Unix milli": { + input: "123010", + expect: time.Unix(123, 10000000), + decodeTimeN: func(v string) (time.Time, error) { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Time{}, &UnmarshalError{ + Err: err, Value: v, Type: timeType, + } + } + return time.Unix(0, n*int64(time.Millisecond)), nil + }, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + inputMap := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeField": &types.AttributeValueMemberN{Value: c.input}, + "TimeFields": &types.AttributeValueMemberNS{Value: []string{c.input}}, + "TimeFieldsL": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: c.input}, + }}, + }, + } + expectedValue := A{ + TimeField: c.expect, + TimeFields: []time.Time{c.expect}, + TimeFieldsL: []time.Time{c.expect}, + } + + var actualValue A + if err := UnmarshalWithOptions(inputMap, &actualValue, func(options *DecoderOptions) { + if c.decodeTimeN != nil { + options.DecodeTime.N = c.decodeTimeN + } + }); err != nil { + t.Errorf("expect no error, got %v", err) + } + if diff := cmpDiff(expectedValue, actualValue); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} + +func TestCustomDecodeSAndDefaultDecodeN(t *testing.T) { + type A struct { + TimeFieldS time.Time + TimeFieldN time.Time + } + inputMap := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeFieldS": &types.AttributeValueMemberS{Value: "01 Jan 70 00:02 UTC"}, + "TimeFieldN": &types.AttributeValueMemberN{Value: "123"}, + }, + } + expectedValue := A{ + TimeFieldS: time.Unix(120, 0).UTC(), + TimeFieldN: time.Unix(123, 0), // will use system's locale + } + + var actualValue A + if err := UnmarshalWithOptions(inputMap, &actualValue, func(options *DecoderOptions) { + // overriding only the S time decoder will keep the default N time decoder + options.DecodeTime.S = func(v string) (time.Time, error) { + t, err := time.Parse(time.RFC822, v) + if err != nil { + return time.Time{}, &UnmarshalError{Err: err, Value: v, Type: timeType} + } + return t, nil + } + }); err != nil { + t.Errorf("expect no error, got %v", err) + } + if diff := cmpDiff(expectedValue, actualValue); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } +} + +func TestCustomDecodeNAndDefaultDecodeS(t *testing.T) { + type A struct { + TimeFieldS time.Time + TimeFieldN time.Time + } + inputMap := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeFieldS": &types.AttributeValueMemberS{Value: "1970-01-01T00:02:03.01Z"}, + "TimeFieldN": &types.AttributeValueMemberN{Value: "123010"}, + }, + } + expectedValue := A{ + TimeFieldS: time.Unix(123, 10000000).UTC(), + TimeFieldN: time.Unix(123, 10000000), // will use system's locale + } + + var actualValue A + if err := UnmarshalWithOptions(inputMap, &actualValue, func(options *DecoderOptions) { + // overriding only the N time decoder will keep the default S time decoder + options.DecodeTime.N = func(v string) (time.Time, error) { + n, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Time{}, &UnmarshalError{ + Err: err, Value: v, Type: timeType, + } + } + return time.Unix(0, n*int64(time.Millisecond)), nil + } + }); err != nil { + t.Errorf("expect no error, got %v", err) + } + if diff := cmpDiff(expectedValue, actualValue); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } +} + +func TestUnmarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input map[string]types.AttributeValue + expectVal interface{} + expectType func() interface{} + }{ + "string key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[string]interface{}{} }, + expectVal: map[string]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "string alias key": { + input: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[StrAlias]interface{}{} }, + expectVal: map[StrAlias]interface{}{ + "a": 123., + "b": "efg", + }, + }, + "Number key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[Number]interface{}{} }, + expectVal: map[Number]interface{}{ + Number("1"): 123., + Number("2"): "efg", + }, + }, + "int key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[int]interface{}{} }, + expectVal: map[int]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "int alias key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[IntAlias]interface{}{} }, + expectVal: map[IntAlias]interface{}{ + 1: 123., + 2: "efg", + }, + }, + "bool key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[bool]interface{}{} }, + expectVal: map[bool]interface{}{ + true: 123., + false: "efg", + }, + }, + "bool alias key": { + input: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[BoolAlias]interface{}{} }, + expectVal: map[BoolAlias]interface{}{ + true: 123., + false: "efg", + }, + }, + "textMarshaler key": { + input: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testTextMarshaler]interface{}{} }, + expectVal: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + "textMarshaler DDBAvMarshaler key": { + input: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + expectType: func() interface{} { return map[testUnmarshalMapKeyComplex]interface{}{} }, + expectVal: map[testUnmarshalMapKeyComplex]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + actualVal := c.expectType() + err := UnmarshalMap(c.input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if diff := cmpDiff(c.expectVal, actualVal); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + }) + } +} + +func TestUnmarshalMap_keyPtrTypes(t *testing.T) { + input := map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + } + + expectVal := map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123., + {Foo: "2"}: "efg", + } + + actualVal := map[*testTextMarshaler]interface{}{} + err := UnmarshalMap(input, &actualVal) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + t.Logf("expectType, %T", actualVal) + + if e, a := len(expectVal), len(actualVal); e != a { + t.Errorf("expect %v values, got %v", e, a) + } + + for k, v := range expectVal { + var found bool + for ak, av := range actualVal { + if *k == *ak { + found = true + if diff := cmpDiff(v, av); diff != "" { + t.Errorf("expect value match\n%s", diff) + } + } + } + if !found { + t.Errorf("expect %v key not found", *k) + } + } + +} + +type textUnmarshalerString string + +func (v *textUnmarshalerString) UnmarshalText(text []byte) error { + *v = textUnmarshalerString("[[" + string(text) + "]]") + return nil +} + +func TestUnmarshalTextString(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo"} + + var actual textUnmarshalerString + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if string(actual) != "[[foo]]" { + t.Errorf("expected [[foo]], got %s", actual) + } +} + +func TestUnmarshalTextStringDisabled(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo"} + + var actual textUnmarshalerString + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = false + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if string(actual) != "foo" { + t.Errorf("expected foo, got %s", actual) + } +} + +type textUnmarshalerStruct struct { + I, J string +} + +func (v *textUnmarshalerStruct) UnmarshalText(text []byte) error { + parts := strings.Split(string(text), ";") + v.I = parts[0] + v.J = parts[1] + return nil +} + +func TestUnmarshalTextStruct(t *testing.T) { + in := &types.AttributeValueMemberS{Value: "foo;bar"} + + var actual textUnmarshalerStruct + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expected := textUnmarshalerStruct{"foo", "bar"} + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +type binaryUnmarshaler struct { + I, J byte +} + +func (v *binaryUnmarshaler) UnmarshalBinary(b []byte) error { + v.I = b[0] + v.J = b[1] + return nil +} + +func TestUnmarshalBinary(t *testing.T) { + in := &types.AttributeValueMemberB{Value: []byte{1, 2}} + + var actual binaryUnmarshaler + err := UnmarshalWithOptions(in, &actual, func(o *DecoderOptions) { + o.UseEncodingUnmarshalers = true + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expected := binaryUnmarshaler{1, 2} + if actual != expected { + t.Errorf("expected %v, got %v", expected, actual) + } +} + +type testStringItem string + +func (t *testStringItem) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + v, ok := av.(*types.AttributeValueMemberS) + if !ok { + return fmt.Errorf("expecting string value") + } + *t = testStringItem(v.Value) + return nil +} + +type testNumberItem float64 + +func (t *testNumberItem) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + v, ok := av.(*types.AttributeValueMemberN) + if !ok { + return fmt.Errorf("expecting number value") + } + n, err := strconv.ParseFloat(v.Value, 64) + if err != nil { + return err + } + *t = testNumberItem(n) + return nil +} + +type testBinaryItem []byte + +func (t *testBinaryItem) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + v, ok := av.(*types.AttributeValueMemberB) + if !ok { + return fmt.Errorf("expecting binary value") + } + *t = make([]byte, len(v.Value)) + copy(*t, v.Value) + return nil +} + +type testStringSetWithUnmarshaler struct { + Strings []testStringItem `dynamodbav:",stringset"` + Numbers []testNumberItem `dynamodbav:",numberset"` + Binaries []testBinaryItem `dynamodbav:",binaryset"` +} + +func TestUnmarshalIndividualSetValues(t *testing.T) { + in := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Strings": &types.AttributeValueMemberSS{ + Value: []string{"a", "b"}, + }, + "Numbers": &types.AttributeValueMemberNS{ + Value: []string{"1", "2"}, + }, + "Binaries": &types.AttributeValueMemberBS{ + Value: [][]byte{{1, 2}, {3, 4}}, + }, + }, + } + var actual testStringSetWithUnmarshaler + err := UnmarshalWithOptions(in, &actual) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + expected := testStringSetWithUnmarshaler{ + Strings: []testStringItem{"a", "b"}, + Numbers: []testNumberItem{1, 2}, + Binaries: []testBinaryItem{{1, 2}, {3, 4}}, + } + if diff := cmpDiff(expected, actual); diff != "" { + t.Errorf("expect value match\n%s", diff) + } +} + +func TestDecodeVersion(t *testing.T) { + cases := []struct { + ft Tag + actual any + expected types.AttributeValue + error bool + }{ + { + ft: Tag{Version: true}, + actual: int(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true}, + actual: uint(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true}, + actual: float32(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true, AsString: true}, + actual: "", + expected: &types.AttributeValueMemberS{ + Value: "", + }, + }, + { + ft: Tag{Version: true}, + actual: "", + expected: &types.AttributeValueMemberS{ + Value: "", + }, + }, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + v := reflect.ValueOf(c.actual) + av, err := NewEncoder[any]().encode(v, c.ft) + + if !c.error && err != nil { + t.Errorf("unexpected error: %v", err) + + return + } + + if c.error && err == nil { + t.Error("expected error") + + return + } + + if diff := cmpDiff(c.expected, av); len(diff) != 0 { + t.Errorf("unexpected diff: %s", diff) + } + }) + } +} + +func TestDecoderWithConverters(t *testing.T) { + cases := []struct { + converter string + input types.AttributeValue + actual any + expected any + expectedError bool + fixedExpected bool + options []string + }{ + {converter: "bool", input: &types.AttributeValueMemberBOOL{Value: true}, actual: aws.Bool(false), expected: aws.Bool(false)}, + {converter: "bool", input: &types.AttributeValueMemberBOOL{Value: false}, actual: aws.Bool(true), expected: aws.Bool(true)}, + {converter: "uint", input: &types.AttributeValueMemberN{Value: "1"}, actual: aws.Uint(uint(0)), expected: aws.Uint(uint(0))}, + {converter: "uint8", input: &types.AttributeValueMemberN{Value: "1"}, actual: aws.Uint8(uint8(0)), expected: aws.Uint8(uint8(0))}, + {converter: "uint16", input: &types.AttributeValueMemberN{Value: "1"}, actual: aws.Uint16(uint16(0)), expected: aws.Uint16(uint16(0))}, + {converter: "uint32", input: &types.AttributeValueMemberN{Value: "1"}, actual: aws.Uint32(uint32(0)), expected: aws.Uint32(uint32(0))}, + {converter: "uint64", input: &types.AttributeValueMemberN{Value: "1"}, actual: aws.Uint64(uint64(0)), expected: aws.Uint64(uint64(0))}, + {converter: "int", input: &types.AttributeValueMemberN{Value: "-1"}, actual: aws.Int(int(0)), expected: aws.Int(int(0))}, + {converter: "int8", input: &types.AttributeValueMemberN{Value: "-1"}, actual: aws.Int8(int8(0)), expected: aws.Int8(int8(0))}, + {converter: "int16", input: &types.AttributeValueMemberN{Value: "-1"}, actual: aws.Int16(int16(0)), expected: aws.Int16(int16(0))}, + {converter: "int32", input: &types.AttributeValueMemberN{Value: "-1"}, actual: aws.Int32(int32(0)), expected: aws.Int32(int32(0))}, + {converter: "int64", input: &types.AttributeValueMemberN{Value: "-1"}, actual: aws.Int64(int64(0)), expected: aws.Int64(int64(0))}, + {converter: "float32", input: &types.AttributeValueMemberN{Value: "1.2"}, actual: aws.Float32(float32(0)), expected: aws.Float32(float32(0))}, + {converter: "float64", input: &types.AttributeValueMemberN{Value: "1.2"}, actual: aws.Float64(float64(0)), expected: aws.Float64(float64(0))}, + {converter: "time.Time", input: &types.AttributeValueMemberN{Value: "1758633434"}, actual: aws.Time(time.Time{}), expected: aws.Time(time.Time{})}, + { + converter: "time.Time", + input: &types.AttributeValueMemberS{Value: "2025-09-23T16:17:14.000+03:00"}, + actual: aws.Time(time.Time{}), + expected: func() any { + o, _ := time.Parse("2006-01-02T15:04:05.999999999Z07:00", "2025-09-23T16:17:14.000+03:00") + + return aws.Time(o) + }, + fixedExpected: true, + options: []string{"2006-01-02T15:04:05.999999999Z07:00"}, + }, + { + converter: "time.Time", + input: &types.AttributeValueMemberS{Value: "2025-09-23 16:17:14"}, + actual: aws.Time(time.Time{}), + expected: func() any { + o, _ := time.Parse("2006-01-02 15:04:05", "2025-09-23 16:17:14") + + return aws.Time(o) + }, + fixedExpected: true, + options: []string{"2006-01-02 15:04:05"}, + }, + { + converter: "json", + input: &types.AttributeValueMemberS{Value: `{"test":"test"}`}, + actual: &map[string]any{}, + expected: &map[string]any{ + "test": "test", + }, + fixedExpected: true, + options: []string{}, + }, + { + converter: "json", + input: &types.AttributeValueMemberS{Value: `[{"test":"test"}]`}, + actual: &[]any{}, + expected: &[]any{ + map[string]any{ + "test": "test", + }, + }, + fixedExpected: true, + options: []string{}, + }, + { + converter: "json", + input: &types.AttributeValueMemberS{Value: `[{"test":"test"}`}, + actual: &[]any{}, + expectedError: true, + fixedExpected: true, + options: []string{}, + }, + } + sd := NewDecoder[order]() + cd := NewDecoder[order](func(options *DecoderOptions) { + options.ConverterRegistry = converters.DefaultRegistry.Clone() + }) + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var err error + + err = cd.decode(c.input, reflect.ValueOf(c.actual), Tag{ + Converter: true, + Options: map[string][]string{ + "converter": append([]string{c.converter}, c.options...), + }, + }) + + if err == nil && c.expectedError { + t.Logf("%#+v", c.actual) + t.Logf("%#+v", c.expected) + t.Logf("%q", c.actual) + t.Logf("%q", c.expected) + t.Fatalf("expected error, got none") + } + + if err != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", err) + } + + if err != nil && c.expectedError { + return + } + + if f, ok := c.expected.(func() any); ok { + c.expected = f() + } + + if !c.fixedExpected { + err = sd.decode(c.input, reflect.ValueOf(c.expected), Tag{ + Converter: true, + Options: map[string][]string{ + "converter": append([]string{c.converter}, c.options...), + }, + }) + if err != nil { + t.Error(err) + } + } + + if diff := cmpDiff(c.expected, c.actual); diff != "" { + t.Errorf("unexpected diff: %v", diff) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/encode.go b/feature/dynamodb/entitymanager/encode.go new file mode 100644 index 00000000000..416a1b2509a --- /dev/null +++ b/feature/dynamodb/entitymanager/encode.go @@ -0,0 +1,943 @@ +package entitymanager + +import ( + "encoding" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager/converters" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// An UnixTime provides aliasing of time.Time into a type that when marshaled +// and unmarshaled with AttributeValues it will be done so as number +// instead of string in seconds since January 1, 1970 UTC. +// +// This type is useful as an alternative to the struct Tag `unixtime` when you +// want to have your time value marshaled as Unix time in seconds into a number +// attribute type instead of the default time.RFC3339Nano. +// +// Important to note that zero value time as unixtime is not 0 seconds +// from January 1, 1970 UTC, but -62135596800. Which is seconds between +// January 1, 0001 UTC, and January 1, 0001 UTC. +// +// Also, important to note: the default UnixTime implementation of the Marshaler +// interface will marshal into an attribute of type of number; therefore, +// it may not be used as a sort key if the attribute value is of type string. Further, +// the time.RFC3339Nano format removes trailing zeros from the seconds Field +// and thus may not sort correctly once formatted. +type UnixTime time.Time + +// MarshalDynamoDBAttributeValue implements the Marshaler interface so that +// the UnixTime can be marshaled from to a AttributeValue number +// value encoded in the number of seconds since January 1, 1970 UTC. +func (e UnixTime) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + return &types.AttributeValueMemberN{ + Value: strconv.FormatInt(time.Time(e).Unix(), 10), + }, nil +} + +// UnmarshalDynamoDBAttributeValue implements the Unmarshaler interface so that +// the UnixTime can be unmarshaled from a AttributeValue number representing +// the number of seconds since January 1, 1970 UTC. +// +// If an error parsing the AttributeValue number occurs UnmarshalError will be +// returned. +func (e *UnixTime) UnmarshalDynamoDBAttributeValue(av types.AttributeValue) error { + tv, ok := av.(*types.AttributeValueMemberN) + if !ok { + return &UnmarshalTypeError{ + Value: fmt.Sprintf("%V", av), + Type: reflect.TypeOf((*UnixTime)(nil)), + } + } + + t, err := decodeUnixTime(tv.Value) + if err != nil { + return err + } + + *e = UnixTime(t) + return nil +} + +// String calls the underlying time.Time.String to return a human readable +// representation. +func (e UnixTime) String() string { + return time.Time(e).String() +} + +// A Marshaler is an interface to provide custom marshaling of Go value types +// to AttributeValues. Use this to provide custom logic determining how a +// Go Value type should be marshaled. +// +// type CustomIntType struct { +// Value Int +// } +// func (m *CustomIntType) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { +// return &types.AttributeValueMemberN{ +// Value: strconv.Itoa(m.Value), +// }, nil +// } +type Marshaler interface { + MarshalDynamoDBAttributeValue() (types.AttributeValue, error) +} + +// Marshal will serialize the passed in Go value type into a AttributeValue +// type. This value can be used in API operations to simplify marshaling +// your Go value types into AttributeValues. +// +// Marshal will recursively transverse the passed in value marshaling its +// contents into a AttributeValue. Marshal supports basic scalars +// (int,uint,float,bool,string), maps, slices, and structs. Anonymous +// nested types are flattened based on Go anonymous type visibility. +// +// Marshaling slices to AttributeValue will default to a List for all +// types except for []byte and [][]byte. []byte will be marshaled as +// Binary data (B), and [][]byte will be marshaled as binary data set +// (BS). +// +// The `time.Time` type is marshaled as `time.RFC3339Nano` format. +// +// `dynamodbav` struct Tag can be used to control how the value will be +// marshaled into a AttributeValue. +// +// // Field is ignored +// Field int `dynamodbav:"-"` +// +// // Field AttributeValue map key "myName" +// Field int `dynamodbav:"myName"` +// +// // Field AttributeValue map key "myName", and +// // Field is omitted if the Field is a zero value for the type. +// Field int `dynamodbav:"myName,omitempty"` +// +// // Field AttributeValue map key "Field", and +// // Field is omitted if the Field is a zero value for the type. +// Field int `dynamodbav:",omitempty"` +// +// // Field's elems will be omitted if the elem's value is empty. +// // only valid for slices, and maps. +// Field []string `dynamodbav:",omitemptyelem"` +// +// // Field AttributeValue map key "Field", and +// // Field is sent as NULL if the Field is a zero value for the type. +// Field int `dynamodbav:",nullempty"` +// +// // Field's elems will be sent as NULL if the elem's value a zero value +// // for the type. Only valid for slices, and maps. +// Field []string `dynamodbav:",nullemptyelem"` +// +// // Field will be marshaled as a AttributeValue string +// // only value for number types, (int,uint,float) +// Field int `dynamodbav:",string"` +// +// // Field will be marshaled as a binary set +// Field [][]byte `dynamodbav:",binaryset"` +// +// // Field will be marshaled as a number set +// Field []int `dynamodbav:",numberset"` +// +// // Field will be marshaled as a string set +// Field []string `dynamodbav:",stringset"` +// +// // Field will be marshaled as Unix time number in seconds. +// // This Tag is only valid with time.Time typed struct fields. +// // Important to note that zero value time as unixtime is not 0 seconds +// // from January 1, 1970 UTC, but -62135596800. Which is seconds between +// // January 1, 0001 UTC, and January 1, 0001 UTC. +// Field time.Time `dynamodbav:",unixtime"` +// +// The omitempty Tag is only used during Marshaling and is ignored for +// Unmarshal. omitempty will skip any member if the Go value of the member is +// zero. The omitemptyelem Tag works the same as omitempty except it applies to +// the elements of maps and slices instead of struct fields, and will not be +// included in the marshaled AttributeValue Map, List, or Set. +// +// The nullempty Tag is only used during Marshaling and is ignored for +// Unmarshal. nullempty will serialize a AttributeValueMemberNULL for the +// member if the Go value of the member is zero. nullemptyelem Tag works the +// same as nullempty except it applies to the elements of maps and slices +// instead of struct fields, and will not be included in the marshaled +// AttributeValue Map, List, or Set. +// +// All struct fields and with anonymous fields, are marshaled unless the +// any of the following conditions are meet. +// +// - the Field is not exported +// - json or dynamodbav Field Tag is "-" +// - json or dynamodbav Field Tag specifies "omitempty", and is a zero value. +// +// Pointer and interfaces values are encoded as the value pointed to or +// contained in the interface. A nil value encodes as the AttributeValue NULL +// value unless `omitempty` struct Tag is provided. +// +// Channel, complex, and function values are not encoded and will be skipped +// when walking the value to be marshaled. +// +// Error that occurs when marshaling will stop the marshal, and return +// the error. +// +// Marshal cannot represent cyclic data structures and will not handle them. +// Passing cyclic structures to Marshal will result in an infinite recursion. +func Marshal[T any](in T) (types.AttributeValue, error) { + return NewEncoder[T]().Encode(in) +} + +// MarshalWithOptions will serialize the passed in Go value type into a AttributeValue +// type, by using . This value can be used in API operations to simplify marshaling +// your Go value types into AttributeValues. +// +// Use the `optsFns` functional options to override the default configuration. +// +// MarshalWithOptions will recursively transverse the passed in value marshaling its +// contents into a AttributeValue. Marshal supports basic scalars +// (int,uint,float,bool,string), maps, slices, and structs. Anonymous +// nested types are flattened based on Go anonymous type visibility. +// +// Marshaling slices to AttributeValue will default to a List for all +// types except for []byte and [][]byte. []byte will be marshaled as +// Binary data (B), and [][]byte will be marshaled as binary data set +// (BS). +// +// The `time.Time` type is marshaled as `time.RFC3339Nano` format. +// +// `dynamodbav` struct Tag can be used to control how the value will be +// marshaled into a AttributeValue. +// +// // Field is ignored +// Field int `dynamodbav:"-"` +// +// // Field AttributeValue map key "myName" +// Field int `dynamodbav:"myName"` +// +// // Field AttributeValue map key "myName", and +// // Field is omitted if the Field is a zero value for the type. +// Field int `dynamodbav:"myName,omitempty"` +// +// // Field AttributeValue map key "Field", and +// // Field is omitted if the Field is a zero value for the type. +// Field int `dynamodbav:",omitempty"` +// +// // Field's elems will be omitted if the elem's value is empty. +// // only valid for slices, and maps. +// Field []string `dynamodbav:",omitemptyelem"` +// +// // Field AttributeValue map key "Field", and +// // Field is sent as NULL if the Field is a zero value for the type. +// Field int `dynamodbav:",nullempty"` +// +// // Field's elems will be sent as NULL if the elem's value a zero value +// // for the type. Only valid for slices, and maps. +// Field []string `dynamodbav:",nullemptyelem"` +// +// // Field will be marshaled as a AttributeValue string +// // only value for number types, (int,uint,float) +// Field int `dynamodbav:",string"` +// +// // Field will be marshaled as a binary set +// Field [][]byte `dynamodbav:",binaryset"` +// +// // Field will be marshaled as a number set +// Field []int `dynamodbav:",numberset"` +// +// // Field will be marshaled as a string set +// Field []string `dynamodbav:",stringset"` +// +// // Field will be marshaled as Unix time number in seconds. +// // This Tag is only valid with time.Time typed struct fields. +// // Important to note that zero value time as unixtime is not 0 seconds +// // from January 1, 1970 UTC, but -62135596800. Which is seconds between +// // January 1, 0001 UTC, and January 1, 0001 UTC. +// Field time.Time `dynamodbav:",unixtime"` +// +// The omitempty Tag is only used during Marshaling and is ignored for +// Unmarshal. omitempty will skip any member if the Go value of the member is +// zero. The omitemptyelem Tag works the same as omitempty except it applies to +// the elements of maps and slices instead of struct fields, and will not be +// included in the marshaled AttributeValue Map, List, or Set. +// +// The nullempty Tag is only used during Marshaling and is ignored for +// Unmarshal. nullempty will serialize a AttributeValueMemberNULL for the +// member if the Go value of the member is zero. nullemptyelem Tag works the +// same as nullempty except it applies to the elements of maps and slices +// instead of struct fields, and will not be included in the marshaled +// AttributeValue Map, List, or Set. +// +// All struct fields and with anonymous fields, are marshaled unless the +// any of the following conditions are meet. +// +// - the Field is not exported +// - json or dynamodbav Field Tag is "-" +// - json or dynamodbav Field Tag specifies "omitempty", and is a zero value. +// +// Pointer and interfaces values are encoded as the value pointed to or +// contained in the interface. A nil value encodes as the AttributeValue NULL +// value unless `omitempty` struct Tag is provided. +// +// Channel, complex, and function values are not encoded and will be skipped +// when walking the value to be marshaled. +// +// Error that occurs when marshaling will stop the marshal, and return +// the error. +// +// MarshalWithOptions cannot represent cyclic data structures and will not handle them. +// Passing cyclic structures to Marshal will result in an infinite recursion. +func MarshalWithOptions[T any](in T, optFns ...func(*EncoderOptions)) (types.AttributeValue, error) { + return NewEncoder[T](optFns...).Encode(in) +} + +// MarshalMap is an alias for Marshal func which marshals Go value type to a +// map of AttributeValues. If the in parameter does not serialize to a map, an +// empty AttributeValue map will be returned. +// +// Use the `optsFns` functional options to override the default configuration. +// +// This is useful for APIs such as PutItem. +func MarshalMap[T any](in T) (map[string]types.AttributeValue, error) { + av, err := NewEncoder[T]().Encode(in) + + asMap, ok := av.(*types.AttributeValueMemberM) + if err != nil || av == nil || !ok { + return map[string]types.AttributeValue{}, err + } + + return asMap.Value, nil +} + +// MarshalMapWithOptions is an alias for MarshalWithOptions func which marshals Go value type to a +// map of AttributeValues. If the in parameter does not serialize to a map, an +// empty AttributeValue map will be returned. +// +// Use the `optsFns` functional options to override the default configuration. +// +// This is useful for APIs such as PutItem. +func MarshalMapWithOptions[T any](in T, optFns ...func(*EncoderOptions)) (map[string]types.AttributeValue, error) { + av, err := NewEncoder[T](optFns...).Encode(in) + + asMap, ok := av.(*types.AttributeValueMemberM) + if err != nil || av == nil || !ok { + return map[string]types.AttributeValue{}, err + } + + return asMap.Value, nil +} + +// MarshalList is an alias for Marshal func which marshals Go value +// type to a slice of AttributeValues. If the in parameter does not serialize +// to a slice, an empty AttributeValue slice will be returned. +func MarshalList[T any](in T) ([]types.AttributeValue, error) { + av, err := NewEncoder[T]().Encode(in) + + asList, ok := av.(*types.AttributeValueMemberL) + if err != nil || av == nil || !ok { + return []types.AttributeValue{}, err + } + + return asList.Value, nil +} + +// MarshalListWithOptions is an alias for MarshalWithOptions func which marshals Go value +// type to a slice of AttributeValues. If the in parameter does not serialize +// to a slice, an empty AttributeValue slice will be returned. +// +// Use the `optsFns` functional options to override the default configuration. +func MarshalListWithOptions[T any](in any, optFns ...func(*EncoderOptions)) ([]types.AttributeValue, error) { + av, err := NewEncoder[T](optFns...).Encode(in) + + asList, ok := av.(*types.AttributeValueMemberL) + if err != nil || av == nil || !ok { + return []types.AttributeValue{}, err + } + + return asList.Value, nil +} + +// EncoderOptions is a collection of options used by the marshaler. +type EncoderOptions struct { + // Support other custom struct Tag keys, such as `yaml`, `json`, or `toml`. + // Note that values provided with a custom TagKey must also be supported + // by the (un)marshalers in this package. + // + // Tag key `dynamodbav` will always be read, but if custom Tag key + // conflicts with `dynamodbav` the custom Tag key value will be used. + TagKey string + + // Will encode any slice being encoded as a set (SS, NS, and BS) as a NULL + // AttributeValue if the slice is not nil, but is empty but contains no + // elements. + // + // If a type implements the Marshal interface, and returns empty set + // slices, this option will not modify the returned value. + // + // Defaults to enabled, because AttributeValue sets cannot currently be + // empty lists. + NullEmptySets bool + + // Will encode time.Time fields + // + // Default encoding is time.RFC3339Nano in a DynamoDB String (S) data type. + EncodeTime func(time.Time) (types.AttributeValue, error) + + // When enabled, the encoder will omit empty time attribute values + OmitEmptyTime bool + + // IgnoreNilValueErrors controls whether decoding should ignore errors + // caused by nil values during schema conversion. + // If true, fields with nil values that cause conversion errors will be skipped. + // If false or nil, such cases will trigger an error. + IgnoreNilValueErrors *bool + + // ConverterRegistry provides a registry of type converters used during + // encoding and decoding operations. It will be set on both the Decoder + // and Encoder to control how values are transformed between Go types + // and schema representations. + ConverterRegistry *converters.Registry +} + +// An Encoder provides marshaling Go value types to AttributeValues. +type Encoder[T any] struct { + options EncoderOptions +} + +// NewEncoder creates a new Encoder with default configuration. Use +// the `opts` functional options to override the default configuration. +func NewEncoder[T any](optFns ...func(*EncoderOptions)) *Encoder[T] { + options := EncoderOptions{ + TagKey: defaultTagKey, + NullEmptySets: true, + EncodeTime: defaultEncodeTime, + } + for _, fn := range optFns { + fn(&options) + } + + if options.EncodeTime == nil { + options.EncodeTime = defaultEncodeTime + } + + return &Encoder[T]{ + options: options, + } +} + +// Encode will marshal a Go value type to an AttributeValue. Returning +// the AttributeValue constructed or error. +func (e *Encoder[T]) Encode(in interface{}) (types.AttributeValue, error) { + return e.encode(reflect.ValueOf(in), Tag{}) +} + +func (e *Encoder[T]) encode(v reflect.Value, fieldTag Tag) (types.AttributeValue, error) { + // Ignore fields explicitly marked to be skipped. + if fieldTag.Ignore { + return nil, nil + } + + if e.options.ConverterRegistry != nil && fieldTag.Converter { + el := valueElem(v) + cvtName := el.Type().String() + + opts, ok := fieldTag.Option("converter") + if ok { + cvtName = opts[0] + } + + if cvt := e.options.ConverterRegistry.Converter(cvtName); cvt != nil { + av, err := cvt.ToAttributeValue(el.Interface(), opts) + + if errors.Is(converters.ErrNilValue, err) && !unwrap(e.options.IgnoreNilValueErrors) { + err = nil + } + + return av, err + } + } + + // Zero values are serialized as null, or skipped if omitEmpty. + if isZeroValue(v) { + if fieldTag.OmitEmpty && fieldTag.NullEmpty { + return nil, &InvalidMarshalError{ + msg: "unable to encode AttributeValue for zero value field with incompatible struct tags, omitempty and nullempty"} + } + + if fieldTag.OmitEmpty { + return nil, nil + } else if isNullableZeroValue(v) || fieldTag.NullEmpty { + return encodeNull(), nil + } + } + + // Handle both pointers and interface conversion into types + v = valueElem(v) + + if v.Kind() != reflect.Invalid { + // time.Time implements too many interfaces so we handle is a special case + if t, ok := v.Interface().(time.Time); ok { + if fieldTag.OmitEmpty && t.IsZero() { + return nil, nil + } + + if fieldTag.AsUnixTime { + return UnixTime(t).MarshalDynamoDBAttributeValue() + } + + if e.options.EncodeTime != nil { + return e.options.EncodeTime(t) + } + + return defaultEncodeTime(t) + } else if av, err := e.tryMarshaler(v); err != nil { + return nil, err + } else if fieldTag.OmitEmpty && isNullAttributeValue(av) { + return nil, nil + } else if av != nil { + return av, nil + } + } + + switch v.Kind() { + case reflect.Invalid: + if fieldTag.OmitEmpty { + return nil, nil + } + // Handle case where member type needed to be dereferenced and resulted + // in a kind that is invalid. + return encodeNull(), nil + + case reflect.Struct: + return e.encodeStruct(v, fieldTag) + + case reflect.Map: + return e.encodeMap(v, fieldTag) + + case reflect.Slice, reflect.Array: + return e.encodeSlice(v, fieldTag) + + case reflect.Chan, reflect.Func, reflect.UnsafePointer: + // skip unsupported types + return nil, nil + + default: + return e.encodeScalar(v, fieldTag) + } +} + +func (e *Encoder[T]) encodeStruct(v reflect.Value, fieldTag Tag) (types.AttributeValue, error) { + // Time structs have no public members, and instead are converted to + // RFC3339Nano formatted string, unix time seconds number if struct Tag is set. + if v.Type().ConvertibleTo(timeType) { + var t time.Time + t = v.Convert(timeType).Interface().(time.Time) + + if e.options.OmitEmptyTime && fieldTag.OmitEmpty && t.IsZero() { + return nil, nil + } + + if fieldTag.AsUnixTime { + return UnixTime(t).MarshalDynamoDBAttributeValue() + } + return e.options.EncodeTime(t) + } + + m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} + fields := unionStructFields(v.Type(), structFieldOptions{ + TagKey: e.options.TagKey, + }) + for _, f := range fields.All() { + if f.Name == "" { + return nil, &InvalidMarshalError{msg: "map key cannot be empty"} + } + + fv, found := encoderFieldByIndex(v, f.Index) + if !found { + continue + } + + elem, err := e.encode(fv, f.Tag) + if err != nil { + return nil, err + } else if elem == nil { + continue + } + + m.Value[f.Name] = elem + } + + return m, nil +} + +func (e *Encoder[T]) encodeMap(v reflect.Value, fieldTag Tag) (types.AttributeValue, error) { + m := &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{}} + for _, key := range v.MapKeys() { + keyName, err := mapKeyAsString(key, fieldTag) + if err != nil { + return nil, err + } + + elemVal := v.MapIndex(key) + elem, err := e.encode(elemVal, Tag{ + OmitEmpty: fieldTag.OmitEmptyElem, + NullEmpty: fieldTag.NullEmptyElem, + }) + if err != nil { + return nil, err + } else if elem == nil { + continue + } + + m.Value[keyName] = elem + } + + return m, nil +} + +func mapKeyAsString(keyVal reflect.Value, fieldTag Tag) (keyStr string, err error) { + defer func() { + if err != nil { + return + } + if keyStr == "" { + err = &InvalidMarshalError{msg: "map key cannot be empty"} + } + }() + + if k, ok := keyVal.Interface().(encoding.TextMarshaler); ok { + b, err := k.MarshalText() + if err != nil { + return "", fmt.Errorf("failed to marshal text, %w", err) + } + return string(b), err + } + + switch keyVal.Kind() { + case reflect.Bool, + reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + + return fmt.Sprint(keyVal.Interface()), nil + + default: + return "", &InvalidMarshalError{ + msg: "map key type not supported, must be string, number, bool, or TextMarshaler", + } + } +} + +func (e *Encoder[T]) encodeSlice(v reflect.Value, fieldTag Tag) (types.AttributeValue, error) { + if v.Type().Elem().Kind() == reflect.Uint8 { + slice := reflect.MakeSlice(byteSliceType, v.Len(), v.Len()) + reflect.Copy(slice, v) + + return &types.AttributeValueMemberB{ + Value: append([]byte{}, slice.Bytes()...), + }, nil + } + + var setElemFn func(types.AttributeValue) error + var av types.AttributeValue + + if fieldTag.AsBinSet || v.Type() == byteSliceSliceType { // Binary Set + if v.Len() == 0 && e.options.NullEmptySets { + return encodeNull(), nil + } + + bs := &types.AttributeValueMemberBS{Value: make([][]byte, 0, v.Len())} + av = bs + setElemFn = func(elem types.AttributeValue) error { + b, ok := elem.(*types.AttributeValueMemberB) + if !ok || b == nil || b.Value == nil { + return &InvalidMarshalError{ + msg: "binary set must only contain non-nil byte slices"} + } + bs.Value = append(bs.Value, b.Value) + return nil + } + + } else if fieldTag.AsNumSet { // Number Set + if v.Len() == 0 && e.options.NullEmptySets { + return encodeNull(), nil + } + + ns := &types.AttributeValueMemberNS{Value: make([]string, 0, v.Len())} + av = ns + setElemFn = func(elem types.AttributeValue) error { + n, ok := elem.(*types.AttributeValueMemberN) + if !ok || n == nil { + return &InvalidMarshalError{ + msg: "number set must only contain non-nil string numbers"} + } + ns.Value = append(ns.Value, n.Value) + return nil + } + + } else if fieldTag.AsStrSet { // String Set + if v.Len() == 0 && e.options.NullEmptySets { + return encodeNull(), nil + } + + ss := &types.AttributeValueMemberSS{Value: make([]string, 0, v.Len())} + av = ss + setElemFn = func(elem types.AttributeValue) error { + s, ok := elem.(*types.AttributeValueMemberS) + if !ok || s == nil { + return &InvalidMarshalError{ + msg: "string set must only contain non-nil strings"} + } + ss.Value = append(ss.Value, s.Value) + return nil + } + + } else { // List + l := &types.AttributeValueMemberL{Value: make([]types.AttributeValue, 0, v.Len())} + av = l + setElemFn = func(elem types.AttributeValue) error { + l.Value = append(l.Value, elem) + return nil + } + } + + if err := e.encodeListElems(v, fieldTag, setElemFn); err != nil { + return nil, err + } + + return av, nil +} + +func (e *Encoder[T]) encodeListElems(v reflect.Value, fieldTag Tag, setElem func(types.AttributeValue) error) error { + for i := 0; i < v.Len(); i++ { + elem, err := e.encode(v.Index(i), Tag{ + OmitEmpty: fieldTag.OmitEmptyElem, + NullEmpty: fieldTag.NullEmptyElem, + }) + if err != nil { + return err + } else if elem == nil { + continue + } + + if err := setElem(elem); err != nil { + return err + } + } + + return nil +} + +// Returns if the type of the value satisfies an interface for number like the +// encoding/json#Number and feature/dynamodb/attributevalue#Number +func isNumberValueType(v reflect.Value) bool { + type numberer interface { + Float64() (float64, error) + Int64() (int64, error) + String() string + } + + _, ok := v.Interface().(numberer) + return ok && v.Kind() == reflect.String +} + +func (e *Encoder[T]) encodeScalar(v reflect.Value, fieldTag Tag) (types.AttributeValue, error) { + if isNumberValueType(v) { + if fieldTag.AsString { + return &types.AttributeValueMemberS{Value: v.String()}, nil + } + return &types.AttributeValueMemberN{Value: v.String()}, nil + } + + switch v.Kind() { + case reflect.Bool: + return &types.AttributeValueMemberBOOL{Value: v.Bool()}, nil + + case reflect.String: + return e.encodeString(v) + + default: + // Fallback to encoding numbers, will return invalid type if not supported + av, err := e.encodeNumber(v) + if err != nil { + return nil, err + } + + n, isNumber := av.(*types.AttributeValueMemberN) + if fieldTag.AsString && isNumber { + return &types.AttributeValueMemberS{Value: n.Value}, nil + } + return av, nil + } +} + +func (e *Encoder[T]) encodeNumber(v reflect.Value) (types.AttributeValue, error) { + + var out string + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out = encodeInt(v.Int()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + out = encodeUint(v.Uint()) + + case reflect.Float32: + out = encodeFloat(v.Float(), 32) + + case reflect.Float64: + out = encodeFloat(v.Float(), 64) + + default: + return nil, nil + } + + return &types.AttributeValueMemberN{Value: out}, nil +} + +func (e *Encoder[T]) encodeString(v reflect.Value) (types.AttributeValue, error) { + + switch v.Kind() { + case reflect.String: + s := v.String() + return &types.AttributeValueMemberS{Value: s}, nil + + default: + return nil, nil + } +} + +func encodeInt(i int64) string { + return strconv.FormatInt(i, 10) +} +func encodeUint(u uint64) string { + return strconv.FormatUint(u, 10) +} +func encodeFloat(f float64, bitSize int) string { + return strconv.FormatFloat(f, 'f', -1, bitSize) +} +func encodeNull() types.AttributeValue { + return &types.AttributeValueMemberNULL{Value: true} +} + +// encoderFieldByIndex finds the Field with the provided nested index +func encoderFieldByIndex(v reflect.Value, index []int) (reflect.Value, bool) { + for i, x := range index { + if i > 0 && v.Kind() == reflect.Ptr && v.Type().Elem().Kind() == reflect.Struct { + if v.IsNil() { + return reflect.Value{}, false + } + v = v.Elem() + } + v = v.Field(x) + } + return v, true +} + +func valueElem(v reflect.Value) reflect.Value { + switch v.Kind() { + case reflect.Interface, reflect.Ptr: + for v.Kind() == reflect.Interface || v.Kind() == reflect.Ptr { + v = v.Elem() + } + } + + return v +} + +func isZeroValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Invalid: + return true + case reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.IsNil() + case reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func isNullableZeroValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Invalid: + return true + case reflect.Map, reflect.Slice: + return v.IsNil() + case reflect.Interface, reflect.Ptr: + return v.IsNil() + } + return false +} + +func (e *Encoder[T]) tryMarshaler(v reflect.Value) (types.AttributeValue, error) { + if v.Kind() != reflect.Ptr && v.Type().Name() != "" && v.CanAddr() { + v = v.Addr() + } + + if v.Type().NumMethod() == 0 { + return nil, nil + } + + i := v.Interface() + if m, ok := i.(Marshaler); ok { + return m.MarshalDynamoDBAttributeValue() + } + + return e.tryEncodingMarshaler(i) +} + +func (e *Encoder[T]) tryEncodingMarshaler(v any) (types.AttributeValue, error) { + if m, ok := v.(encoding.TextMarshaler); ok { + s, err := m.MarshalText() + if err != nil { + return nil, err + } + + return &types.AttributeValueMemberS{Value: string(s)}, nil + } + + if m, ok := v.(encoding.BinaryMarshaler); ok { + b, err := m.MarshalBinary() + if err != nil { + return nil, err + } + + return &types.AttributeValueMemberB{Value: b}, nil + } + + return nil, nil +} + +// An InvalidMarshalError is an error type representing an error +// occurring when marshaling a Go value type to an AttributeValue. +type InvalidMarshalError struct { + msg string +} + +// Error returns the string representation of the error. +// satisfying the error interface +func (e *InvalidMarshalError) Error() string { + return fmt.Sprintf("marshal failed, %s", e.msg) +} + +func defaultEncodeTime(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{ + Value: t.Format(time.RFC3339Nano), + }, nil +} + +func isNullAttributeValue(av types.AttributeValue) bool { + n, ok := av.(*types.AttributeValueMemberNULL) + return ok && n.Value +} diff --git a/feature/dynamodb/entitymanager/encode_test.go b/feature/dynamodb/entitymanager/encode_test.go new file mode 100644 index 00000000000..54f4d5b56e0 --- /dev/null +++ b/feature/dynamodb/entitymanager/encode_test.go @@ -0,0 +1,1046 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "strconv" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager/converters" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestMarshalShared(t *testing.T) { + for name, c := range sharedTestCases { + t.Run(name, func(t *testing.T) { + av, err := Marshal(c.expected) + assertConvertTest(t, av, c.in, err, c.err) + }) + } +} + +func TestMarshalListShared(t *testing.T) { + for name, c := range sharedListTestCases { + t.Run(name, func(t *testing.T) { + av, err := MarshalList(c.expected) + assertConvertTest(t, av, c.in, err, c.err) + }) + } +} + +func TestMarshalMapShared(t *testing.T) { + for name, c := range sharedMapTestCases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.expected) + assertConvertTest(t, av, c.in, err, c.err) + }) + } +} + +type marshalMarshaler struct { + Value string + Value2 int + Value3 bool + Value4 time.Time +} + +func (m *marshalMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + return &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: m.Value}, + "def": &types.AttributeValueMemberN{Value: strconv.Itoa(m.Value2)}, + "ghi": &types.AttributeValueMemberBOOL{Value: m.Value3}, + "jkl": &types.AttributeValueMemberS{Value: m.Value4.Format(time.RFC3339Nano)}, + }, + }, nil +} + +func TestMarshalMashaler(t *testing.T) { + m := &marshalMarshaler{ + Value: "value", + Value2: 123, + Value3: true, + Value4: testDate, + } + + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "value"}, + "def": &types.AttributeValueMemberN{Value: "123"}, + "ghi": &types.AttributeValueMemberBOOL{Value: true}, + "jkl": &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + }, + } + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +type customBoolStringMarshaler string + +func (m customBoolStringMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + + if b, err := strconv.ParseBool(string(m)); err == nil { + return &types.AttributeValueMemberBOOL{Value: b}, nil + } + + return &types.AttributeValueMemberS{Value: string(m)}, nil +} + +type customTextMarshaler struct { + I, J int +} + +func (v customTextMarshaler) MarshalText() ([]byte, error) { + text := fmt.Sprintf("{I: %d, J: %d}", v.I, v.J) + return []byte(text), nil +} + +type customBinaryMarshaler struct { + I, J byte +} + +func (v customBinaryMarshaler) MarshalBinary() ([]byte, error) { + return []byte{v.I, v.J}, nil +} + +type customAVAndTextMarshaler struct { + I, J int +} + +func (v customAVAndTextMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + return &types.AttributeValueMemberNS{Value: []string{ + fmt.Sprintf("%d", v.I), + fmt.Sprintf("%d", v.J), + }}, nil +} + +func (v customAVAndTextMarshaler) MarshalText() ([]byte, error) { + return []byte("should never happen"), nil +} + +func TestEncodingMarshalers(t *testing.T) { + cases := []struct { + input any + expected types.AttributeValue + }{ + { + input: customTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberS{Value: "{I: 1, J: 2}"}, + }, + { + input: customBinaryMarshaler{1, 2}, + expected: &types.AttributeValueMemberB{Value: []byte{1, 2}}, + }, + { + input: customAVAndTextMarshaler{1, 2}, + expected: &types.AttributeValueMemberNS{Value: []string{"1", "2"}}, + }, + } + + for _, testCase := range cases { + actual, err := MarshalWithOptions(testCase.input) + if err != nil { + t.Errorf("got unexpected error %v for input %v", err, testCase.input) + } + if diff := cmpDiff(testCase.expected, actual); len(diff) != 0 { + t.Errorf("expected match but got: %s", diff) + } + } +} + +func TestCustomStringMarshaler(t *testing.T) { + cases := []struct { + expected types.AttributeValue + input string + }{ + { + expected: &types.AttributeValueMemberBOOL{Value: false}, + input: "false", + }, + { + expected: &types.AttributeValueMemberBOOL{Value: true}, + input: "true", + }, + { + expected: &types.AttributeValueMemberS{Value: "ABC"}, + input: "ABC", + }, + } + + for _, testCase := range cases { + input := customBoolStringMarshaler(testCase.input) + actual, err := Marshal(input) + if err != nil { + t.Errorf("got unexpected error %v for input %v", err, testCase.input) + } + if diff := cmpDiff(testCase.expected, actual); len(diff) != 0 { + t.Errorf("expected match but got:%s", diff) + } + } +} + +type customGradeMarshaler uint + +func (m customGradeMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + if int(m) > 100 { + return nil, fmt.Errorf("grade cant be larger then 100") + } + return &types.AttributeValueMemberN{Value: strconv.FormatUint(uint64(m), 10)}, nil +} + +func TestCustomNumberMarshaler(t *testing.T) { + cases := []struct { + expectedErr bool + input uint + expected types.AttributeValue + }{ + { + expectedErr: false, + input: 50, + expected: &types.AttributeValueMemberN{Value: "50"}, + }, + { + expectedErr: false, + input: 90, + expected: &types.AttributeValueMemberN{Value: "90"}, + }, + { + expectedErr: true, + input: 150, + expected: nil, + }, + } + + for _, testCase := range cases { + input := customGradeMarshaler(testCase.input) + actual, err := Marshal(customGradeMarshaler(input)) + if testCase.expectedErr && err == nil { + t.Errorf("expected error but got nil for input %v", testCase.input) + continue + } + if !testCase.expectedErr && err != nil { + t.Errorf("got unexpected error %v for input %v", err, testCase.input) + continue + } + if diff := cmpDiff(testCase.expected, actual); len(diff) != 0 { + t.Errorf("expected match but got:%s", diff) + } + } +} + +type testOmitEmptyElemListStruct struct { + Values []string `dynamodbav:",omitemptyelem"` +} + +type testOmitEmptyElemMapStruct struct { + Values map[string]interface{} `dynamodbav:",omitemptyelem"` +} + +func TestMarshalListOmitEmptyElem(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Values": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "abc"}, + &types.AttributeValueMemberS{Value: "123"}, + }}, + }, + } + + m := testOmitEmptyElemListStruct{Values: []string{"abc", "", "123"}} + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if diff := cmpDiff(expect, actual); len(diff) != 0 { + t.Errorf("expect match\n%s", diff) + } +} + +func TestMarshalMapOmitEmptyElem(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Values": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberN{Value: "123"}, + "hij": &types.AttributeValueMemberS{Value: ""}, + "klm": &types.AttributeValueMemberS{Value: "abc"}, + "qrs": &types.AttributeValueMemberS{Value: "abc"}, + }}, + }, + } + + m := testOmitEmptyElemMapStruct{Values: map[string]interface{}{ + "abc": 123., + "efg": nil, + "hij": "", + "klm": "abc", + "nop": func() interface{} { + var v *string + return v + }(), + "qrs": func() interface{} { + v := "abc" + return &v + }(), + }} + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if diff := cmpDiff(expect, actual); len(diff) != 0 { + t.Errorf("expect match\n%s", diff) + } +} + +type testNullEmptyElemListStruct struct { + Values []string `dynamodbav:",nullemptyelem"` +} + +type testNullEmptyElemMapStruct struct { + Values map[string]interface{} `dynamodbav:",nullemptyelem"` +} + +func TestMarshalListNullEmptyElem(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Values": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "abc"}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberS{Value: "123"}, + }}, + }, + } + + m := testNullEmptyElemListStruct{Values: []string{"abc", "", "123"}} + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if diff := cmpDiff(expect, actual); len(diff) != 0 { + t.Errorf("expect match\n%s", diff) + } +} + +func TestMarshalMapNullEmptyElem(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Values": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberN{Value: "123"}, + "efg": &types.AttributeValueMemberNULL{Value: true}, + "hij": &types.AttributeValueMemberS{Value: ""}, + "klm": &types.AttributeValueMemberS{Value: "abc"}, + "nop": &types.AttributeValueMemberNULL{Value: true}, + "qrs": &types.AttributeValueMemberS{Value: "abc"}, + }}, + }, + } + + m := testNullEmptyElemMapStruct{Values: map[string]interface{}{ + "abc": 123., + "efg": nil, + "hij": "", + "klm": "abc", + "nop": func() interface{} { + var v *string + return v + }(), + "qrs": func() interface{} { + v := "abc" + return &v + }(), + }} + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if diff := cmpDiff(expect, actual); len(diff) != 0 { + t.Errorf("expect match\n%s", diff) + } +} + +type testOmitEmptyScalar struct { + IntZero int `dynamodbav:",omitempty"` + IntPtrNil *int `dynamodbav:",omitempty"` + IntPtrSetZero *int `dynamodbav:",omitempty"` +} + +func TestMarshalOmitEmpty(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "IntPtrSetZero": &types.AttributeValueMemberN{Value: "0"}, + }, + } + + m := testOmitEmptyScalar{IntPtrSetZero: aws.Int(0)} + + actual, err := Marshal(m) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +type customNullMarshaler struct{} + +func (m customNullMarshaler) MarshalDynamoDBAttributeValue() (types.AttributeValue, error) { + return &types.AttributeValueMemberNULL{Value: true}, nil +} + +type testOmitEmptyCustom struct { + CustomNullOmit customNullMarshaler `dynamodbav:",omitempty"` + CustomNullOmitTagKey customNullMarshaler `tagkey:",omitempty"` + CustomNullPresent customNullMarshaler + EmptySetOmit []string `dynamodbav:",omitempty"` +} + +func TestMarshalOmitEmptyCustom(t *testing.T) { + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "CustomNullPresent": &types.AttributeValueMemberNULL{Value: true}, + }, + } + + m := testOmitEmptyCustom{} + + actual, err := MarshalWithOptions(m, func(eo *EncoderOptions) { + eo.TagKey = "tagkey" + eo.NullEmptySets = true + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestEncodeEmbeddedPointerStruct(t *testing.T) { + type B struct { + Bint int + } + type C struct { + Cint int + } + type A struct { + Aint int + *B + *C + } + a := A{Aint: 321, B: &B{123}} + if e, a := 321, a.Aint; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := 123, a.Bint; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if a.C != nil { + t.Errorf("expect nil, got %v", a.C) + } + + actual, err := Marshal(a) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Aint": &types.AttributeValueMemberN{Value: "321"}, + "Bint": &types.AttributeValueMemberN{Value: "123"}, + }, + } + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestEncodeUnixTime(t *testing.T) { + type A struct { + Normal time.Time + Tagged time.Time `dynamodbav:",unixtime"` + Typed UnixTime + } + + a := A{ + Normal: time.Unix(123, 0).UTC(), + Tagged: time.Unix(456, 0), + Typed: UnixTime(time.Unix(789, 0)), + } + + actual, err := Marshal(a) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Normal": &types.AttributeValueMemberS{Value: "1970-01-01T00:02:03Z"}, + "Tagged": &types.AttributeValueMemberN{Value: "456"}, + "Typed": &types.AttributeValueMemberN{Value: "789"}, + }, + } + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestUnixTimeString(t *testing.T) { + gotime := time.Date(2016, time.May, 03, 17, 06, 26, 0, time.UTC) + ddbtime := UnixTime(gotime) + if fmt.Sprint(gotime) != fmt.Sprint(ddbtime) { + t.Error("UnixTime.String not equal to time.Time.String") + } +} + +type AliasedTime time.Time + +func TestEncodeAliasedUnixTime(t *testing.T) { + type A struct { + Normal AliasedTime + Tagged AliasedTime `dynamodbav:",unixtime"` + } + + a := A{ + Normal: AliasedTime(time.Unix(123, 0).UTC()), + Tagged: AliasedTime(time.Unix(456, 0)), + } + + actual, err := Marshal(a) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Normal": &types.AttributeValueMemberS{Value: "1970-01-01T00:02:03Z"}, + "Tagged": &types.AttributeValueMemberN{Value: "456"}, + }, + } + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestMarshalTime_S(t *testing.T) { + type A struct { + TimeField time.Time + TimeFieldsL []time.Time + } + cases := map[string]struct { + input time.Time + expect string + encodeTime func(time.Time) (types.AttributeValue, error) + }{ + "String RFC3339Nano (Default)": { + input: time.Unix(123, 10000000).UTC(), + expect: "1970-01-01T00:02:03.01Z", + }, + "String UnixDate": { + input: time.Unix(123, 0).UTC(), + expect: "Thu Jan 1 00:02:03 UTC 1970", + encodeTime: func(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{ + Value: t.Format(time.UnixDate), + }, nil + }, + }, + "String RFC3339 millis keeping zeroes": { + input: time.Unix(123, 10000000).UTC(), + expect: "1970-01-01T00:02:03.010Z", + encodeTime: func(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{ + Value: t.Format("2006-01-02T15:04:05.000Z07:00"), // Would be RFC3339 millis with zeroes + }, nil + }, + }, + "String RFC822": { + input: time.Unix(120, 0).UTC(), + expect: "01 Jan 70 00:02 UTC", + encodeTime: func(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberS{ + Value: t.Format(time.RFC822), + }, nil + }, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + inputValue := A{ + TimeField: c.input, + TimeFieldsL: []time.Time{c.input}, + } + actual, err := MarshalWithOptions(inputValue, func(eo *EncoderOptions) { + if c.encodeTime != nil { + eo.EncodeTime = c.encodeTime + } + }) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + expectedValue := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeField": &types.AttributeValueMemberS{Value: c.expect}, + "TimeFieldsL": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: c.expect}, + }}, + }, + } + if diff := cmpDiff(expectedValue, actual); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} + +func TestMarshalTime_N(t *testing.T) { + type A struct { + TimeField time.Time + TimeFieldsL []time.Time + } + cases := map[string]struct { + input time.Time + expect string + encodeTime func(time.Time) (types.AttributeValue, error) + }{ + "Number Unix seconds": { + input: time.Unix(123, 10000000).UTC(), + expect: "123", + encodeTime: func(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberN{ + Value: strconv.Itoa(int(t.Unix())), + }, nil + }, + }, + "Number Unix milli": { + input: time.Unix(123, 10000000).UTC(), + expect: "123010", + encodeTime: func(t time.Time) (types.AttributeValue, error) { + return &types.AttributeValueMemberN{ + Value: strconv.Itoa(int(t.UnixNano() / int64(time.Millisecond))), + }, nil + }, + }, + } + for name, c := range cases { + t.Run(name, func(t *testing.T) { + inputValue := A{ + TimeField: c.input, + TimeFieldsL: []time.Time{c.input}, + } + actual, err := MarshalWithOptions(inputValue, func(eo *EncoderOptions) { + if c.encodeTime != nil { + eo.EncodeTime = c.encodeTime + } + }) + if err != nil { + t.Errorf("expect no err, got %v", err) + } + expectedValue := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "TimeField": &types.AttributeValueMemberN{Value: c.expect}, + "TimeFieldsL": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: c.expect}, + }}, + }, + } + if diff := cmpDiff(expectedValue, actual); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} + +func TestEncoderFieldByIndex(t *testing.T) { + type ( + Middle struct{ Inner int } + Outer struct{ *Middle } + ) + + // nil embedded struct + outer := Outer{} + outerFields := unionStructFields(reflect.TypeOf(outer), structFieldOptions{}) + innerField, _ := outerFields.FieldByName("Inner") + + _, found := encoderFieldByIndex(reflect.ValueOf(&outer).Elem(), innerField.Index) + if found != false { + t.Error("expected found to be false when embedded struct is nil") + } + + // non-nil embedded struct + outer = Outer{Middle: &Middle{Inner: 3}} + outerFields = unionStructFields(reflect.TypeOf(outer), structFieldOptions{}) + innerField, _ = outerFields.FieldByName("Inner") + + f, found := encoderFieldByIndex(reflect.ValueOf(&outer).Elem(), innerField.Index) + if !found { + t.Error("expected found to be true") + } + if f.Kind() != reflect.Int || f.Int() != int64(outer.Inner) { + t.Error("expected f to be of kind Int with value equal to outer.Inner") + } +} + +func TestMarshalMap_keyTypes(t *testing.T) { + type StrAlias string + type IntAlias int + type BoolAlias bool + + cases := map[string]struct { + input interface{} + expectAV map[string]types.AttributeValue + }{ + "string key": { + input: map[string]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "string alias key": { + input: map[StrAlias]interface{}{ + "a": 123, + "b": "efg", + }, + expectAV: map[string]types.AttributeValue{ + "a": &types.AttributeValueMemberN{Value: "123"}, + "b": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "Number key": { + input: map[Number]interface{}{ + Number("1"): 123, + Number("2"): "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int key": { + input: map[int]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "int alias key": { + input: map[IntAlias]interface{}{ + 1: 123, + 2: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "1": &types.AttributeValueMemberN{Value: "123"}, + "2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool key": { + input: map[bool]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "bool alias key": { + input: map[BoolAlias]interface{}{ + true: 123, + false: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberN{Value: "123"}, + "false": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler key": { + input: map[testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + "textMarshaler pointer key": { + input: map[*testTextMarshaler]interface{}{ + {Foo: "1"}: 123, + {Foo: "2"}: "efg", + }, + expectAV: map[string]types.AttributeValue{ + "Foo:1": &types.AttributeValueMemberN{Value: "123"}, + "Foo:2": &types.AttributeValueMemberS{Value: "efg"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + av, err := MarshalMap(c.input) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if diff := cmpDiff(c.expectAV, av); diff != "" { + t.Errorf("expect attribute value match\n%s", diff) + } + }) + } +} + +func TestEncodeEmptyTime(t *testing.T) { + type A struct { + Created time.Time `dynamodbav:"created,omitempty"` + } + + a := A{Created: time.Time{}} + + actual, err := MarshalWithOptions(a, func(o *EncoderOptions) { + o.OmitEmptyTime = true + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + + expect := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{}, + } + + if e, a := expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } + + actual2, err := MarshalMapWithOptions(a, func(o *EncoderOptions) { + o.OmitEmptyTime = true + }) + if err != nil { + t.Errorf("expect nil, got %v", err) + } + + expect2 := map[string]types.AttributeValue{} + + if e, a := expect2, actual2; !reflect.DeepEqual(e, a) { + t.Errorf("expect %v, got %v", e, a) + } +} + +func TestEncodeVersion(t *testing.T) { + cases := []struct { + ft Tag + actual any + expected types.AttributeValue + error bool + }{ + { + ft: Tag{Version: true}, + actual: int(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true}, + actual: uint(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true}, + actual: float32(5), + expected: &types.AttributeValueMemberN{ + Value: "5", + }, + }, + { + ft: Tag{Version: true, AsString: true}, + actual: "", + expected: &types.AttributeValueMemberS{ + Value: "", + }, + }, + { + ft: Tag{Version: true}, + actual: "", + expected: &types.AttributeValueMemberS{ + Value: "", + }, + }, + } + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + v := reflect.ValueOf(c.actual) + av, err := NewEncoder[any]().encode(v, c.ft) + + if !c.error && err != nil { + t.Errorf("unexpected error: %v", err) + + return + } + + if c.error && err == nil { + t.Error("expected error") + + return + } + + if diff := cmpDiff(av, c.expected); len(diff) != 0 { + t.Errorf("unexpected diff: %s", diff) + + fmt.Printf("%#+v\n", av) + fmt.Printf("%#+v\n", c.expected) + } + }) + } +} + +func TestEncodeWithConverter(t *testing.T) { + cases := []struct { + converter string + input any + expected types.AttributeValue + expectedError bool + fixedExpected bool + options []string + }{ + {converter: "bool", expected: &types.AttributeValueMemberBOOL{Value: true}, input: reflect.ValueOf(false)}, + {converter: "bool", expected: &types.AttributeValueMemberBOOL{Value: false}, input: reflect.ValueOf(true)}, + {converter: "uint", expected: &types.AttributeValueMemberN{Value: "1"}, input: reflect.ValueOf(uint(0))}, + {converter: "uint8", expected: &types.AttributeValueMemberN{Value: "1"}, input: reflect.ValueOf(uint8(0))}, + {converter: "uint16", expected: &types.AttributeValueMemberN{Value: "1"}, input: reflect.ValueOf(uint16(0))}, + {converter: "uint32", expected: &types.AttributeValueMemberN{Value: "1"}, input: reflect.ValueOf(uint32(0))}, + {converter: "uint64", expected: &types.AttributeValueMemberN{Value: "1"}, input: reflect.ValueOf(uint64(0))}, + {converter: "int", expected: &types.AttributeValueMemberN{Value: "-1"}, input: reflect.ValueOf(int(0))}, + {converter: "int8", expected: &types.AttributeValueMemberN{Value: "-1"}, input: reflect.ValueOf(int8(0))}, + {converter: "int16", expected: &types.AttributeValueMemberN{Value: "-1"}, input: reflect.ValueOf(int16(0))}, + {converter: "int32", expected: &types.AttributeValueMemberN{Value: "-1"}, input: reflect.ValueOf(int32(0))}, + {converter: "int64", expected: &types.AttributeValueMemberN{Value: "-1"}, input: reflect.ValueOf(int64(0))}, + {converter: "float32", expected: &types.AttributeValueMemberN{Value: "1.2"}, input: reflect.ValueOf(float32(0))}, + {converter: "float64", expected: &types.AttributeValueMemberN{Value: "1.2"}, input: reflect.ValueOf(float64(0))}, + {converter: "time.Time", expected: &types.AttributeValueMemberN{Value: "1758633434"}, input: reflect.ValueOf(time.Time{})}, + { + converter: "time.Time", + expected: &types.AttributeValueMemberS{Value: "2025-09-23T16:17:14+03:00"}, + input: func() reflect.Value { + o, _ := time.Parse("2006-01-02T15:04:05.999999999Z07:00", "2025-09-23T16:17:14+03:00") + + return reflect.ValueOf(o) + }, + fixedExpected: true, + options: []string{"2006-01-02T15:04:05Z07:00"}, + }, + { + converter: "time.Time", + expected: &types.AttributeValueMemberS{Value: "2025-09-23T16:17:14Z"}, + input: func() reflect.Value { + o, _ := time.Parse("2006-01-02T15:04:05Z", "2025-09-23T16:17:14Z") + + return reflect.ValueOf(o) + }, + fixedExpected: true, + options: []string{"2006-01-02T15:04:05Z"}, + }, + { + converter: "json", + expected: &types.AttributeValueMemberS{Value: `{"test":"test"}`}, + input: reflect.ValueOf(map[string]any{ + "test": "test", + }), + fixedExpected: true, + options: []string{}, + }, + { + converter: "json", + expected: &types.AttributeValueMemberS{Value: `[{"test":"test"}]`}, + input: reflect.ValueOf([]any{ + map[string]any{ + "test": "test", + }, + }), + fixedExpected: true, + options: []string{}, + }, + { + converter: "json", + expected: &types.AttributeValueMemberS{Value: "null"}, + input: reflect.ValueOf((map[string]string)(nil)), + expectedError: false, + fixedExpected: true, + options: []string{}, + }, + } + se := NewEncoder[order]() + ce := NewEncoder[order](func(options *EncoderOptions) { + options.ConverterRegistry = converters.DefaultRegistry.Clone() + }) + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var err error + var actual types.AttributeValue + var input reflect.Value + + switch v := c.input.(type) { + case reflect.Value: + input = v + case func() reflect.Value: + input = v() + default: + t.Error("unexpected test case input type") + + return + } + + actual, err = ce.encode(input, Tag{ + Converter: true, + Options: map[string][]string{ + "converter": append([]string{c.converter}, c.options...), + }, + }) + + if err == nil && c.expectedError { + t.Fatalf("expected error, got none") + } + + if err != nil && !c.expectedError { + t.Fatalf("unexpected error, got: %v", err) + } + + if err != nil && c.expectedError { + return + } + + if !c.fixedExpected { + c.expected, err = se.encode(input, Tag{ + Converter: true, + Options: map[string][]string{ + "converter": append([]string{c.converter}, c.options...), + }, + }) + if err != nil { + t.Error(err) + } + } + + if diff := cmpDiff(c.expected, actual); diff != "" { + t.Errorf("unexpected diff: %v", diff) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/examples/crud/go.mod b/feature/dynamodb/entitymanager/examples/crud/go.mod new file mode 100644 index 00000000000..bc7c7855e6f --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/crud/go.mod @@ -0,0 +1,30 @@ +module canaries + +go 1.23.0 + +replace github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager => ../.. + +require ( + github.com/aws/aws-sdk-go-v2 v1.39.6 + github.com/aws/aws-sdk-go-v2/config v1.31.17 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager v0.0.0-00010101000000-000000000000 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect + github.com/aws/smithy-go v1.23.2 // indirect +) diff --git a/feature/dynamodb/entitymanager/examples/crud/go.sum b/feature/dynamodb/entitymanager/examples/crud/go.sum new file mode 100644 index 00000000000..03c8e196b8d --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/crud/go.sum @@ -0,0 +1,36 @@ +github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= +github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= +github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y= +github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 h1:HmfJKLqMk6nVmyZrScNXVlnqfhIeiAcsJozT1Md8pBI= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19/go.mod h1:BVQAm94IwIMmbNGwd7inlFczhZl75gIQWK7SejQPSRA= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 h1:zgzPsusBAei9dD+PUb+26W7Ju6q/MM8+SrSCL7abJ54= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19/go.mod h1:oQ3PiGmB6gdUDAO6y4XAOVkG/biM7qMxI/522eZLjMc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 h1:v63QYOleHhBT1SctUsl4RXH+yjYuxQzpGxFRfjCmXBc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2/go.mod h1:OU+zHNgIjScCe8j2GAZ7uEWVMH3UupqAp2c2gpyckEE= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 h1:ccmQULuINm6Yj9ynQY5+6rnDnGXCVQnWh5aqVDec+K8= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0/go.mod h1:kPSrLRdnPrs1oEl7B5f6DInj2kpv3ePyh/Ow22zXlrw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 h1:E+Q3COWEOkzzxo3kxG6zUskB3qsNMG/+UWbuREq5b9M= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11/go.mod h1:p2NzdJjY5n+i+BAf9iw5jZRURdplXLX47IRB8LP2AgQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= +github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= +github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= diff --git a/feature/dynamodb/entitymanager/examples/crud/main.go b/feature/dynamodb/entitymanager/examples/crud/main.go new file mode 100644 index 00000000000..3db25e01c28 --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/crud/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "fmt" + "log" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +// Item example struct +type Item struct { + ID string `dynamodbav:"id,partition"` + Email string `dynamodbav:"email,sort"` + Name string `dynamodbav:"name"` + Body string `dynamodbav:"body"` +} + +func main() { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + panic(err) + } + + tableName := fmt.Sprintf("table_%s", time.Now().Format("2006_01_02_15_04_05")) + + ddb := dynamodb.NewFromConfig(cfg) + + sch, err := entitymanager.NewSchema[Item]() + if err != nil { + panic(err) + } + + sch = sch.WithTableName(aws.String(tableName)) + + tbl, err := entitymanager.NewTable[Item](ddb, func(options *entitymanager.TableOptions[Item]) { + options.Schema = sch + }) + if err != nil { + panic(err) + } + + if exists, err := tbl.Exists(context.Background()); !exists || err != nil { + if err != nil { + panic(err) + } + + if err := tbl.CreateWithWait(context.Background(), time.Minute*2); err != nil { + panic(err) + } + + defer func() { + if err := tbl.DeleteWithWait(context.Background(), time.Minute*2); err != nil { + panic(err) + } + }() + } + + log.Print("PutItem() up to 10") + for c := range 10 { + i, err := tbl.PutItem(context.Background(), &Item{ + ID: fmt.Sprintf("%d", c), + Email: fmt.Sprintf("user-%d@amazon.dev.null", c), + Name: fmt.Sprintf("First%d", c), + Body: fmt.Sprintf("Last%d", c), + }) + if err != nil { + log.Printf("Error putting item: %v", err) + } + if i != nil { + log.Printf("Put item %#+v", i) + } + } + + log.Print("GetItem() up to 10") + for c := range 10 { + m := entitymanager.Map{}. + With("id", fmt.Sprintf("%d", c)). + With("email", fmt.Sprintf("user-%d@amazon.dev.null", c)) + + i, err := tbl.GetItem( + context.Background(), + m, + ) + if err != nil { + log.Printf("Error getting item %v: %v", m, err) + } + if i != nil { + log.Printf("Got item %#+v", i) + } + } + + log.Print("Query()") + { + keyCond := expression.Key("id").Equal(expression.Value("1")) //. + // And(expression.Key("email").BeginsWith("user")) + + expr, err := expression.NewBuilder().WithKeyCondition(keyCond).Build() + if err != nil { + panic(err) + } + + for res := range tbl.Query(context.Background(), expr) { + if res.Error() != nil { + log.Printf("error: %v", res.Error()) + + continue + } + + log.Printf("Got item %#+v", res.Item()) + } + } + + log.Print("Scan()") + { + f := expression.Name("id").Contains("1"). + And(expression.Name("email").Contains("user")) + + expr, err := expression.NewBuilder().WithFilter(f).Build() + if err != nil { + panic(err) + } + + total := 0 + for res := range tbl.Scan(context.Background(), expr) { + if res.Error() != nil { + log.Printf("Scan() error: %v", res.Error()) + + continue + } + + id := res.Item().ID + if strings.Contains(id, "1") { + log.Printf("Got item %#+v", id) + } else { + log.Printf("Error: %v", id) + } + + total++ + } + + log.Printf("Total: %d", total) + } +} diff --git a/feature/dynamodb/entitymanager/examples/single-table/go.mod b/feature/dynamodb/entitymanager/examples/single-table/go.mod new file mode 100644 index 00000000000..bc7c7855e6f --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/single-table/go.mod @@ -0,0 +1,30 @@ +module canaries + +go 1.23.0 + +replace github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager => ../.. + +require ( + github.com/aws/aws-sdk-go-v2 v1.39.6 + github.com/aws/aws-sdk-go-v2/config v1.31.17 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager v0.0.0-00010101000000-000000000000 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.18.21 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect + github.com/aws/smithy-go v1.23.2 // indirect +) diff --git a/feature/dynamodb/entitymanager/examples/single-table/go.sum b/feature/dynamodb/entitymanager/examples/single-table/go.sum new file mode 100644 index 00000000000..03c8e196b8d --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/single-table/go.sum @@ -0,0 +1,36 @@ +github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= +github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= +github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y= +github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA= +github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 h1:HmfJKLqMk6nVmyZrScNXVlnqfhIeiAcsJozT1Md8pBI= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19/go.mod h1:BVQAm94IwIMmbNGwd7inlFczhZl75gIQWK7SejQPSRA= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 h1:zgzPsusBAei9dD+PUb+26W7Ju6q/MM8+SrSCL7abJ54= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19/go.mod h1:oQ3PiGmB6gdUDAO6y4XAOVkG/biM7qMxI/522eZLjMc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 h1:v63QYOleHhBT1SctUsl4RXH+yjYuxQzpGxFRfjCmXBc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2/go.mod h1:OU+zHNgIjScCe8j2GAZ7uEWVMH3UupqAp2c2gpyckEE= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 h1:ccmQULuINm6Yj9ynQY5+6rnDnGXCVQnWh5aqVDec+K8= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0/go.mod h1:kPSrLRdnPrs1oEl7B5f6DInj2kpv3ePyh/Ow22zXlrw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 h1:E+Q3COWEOkzzxo3kxG6zUskB3qsNMG/+UWbuREq5b9M= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11/go.mod h1:p2NzdJjY5n+i+BAf9iw5jZRURdplXLX47IRB8LP2AgQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs= +github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= +github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= +github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= diff --git a/feature/dynamodb/entitymanager/examples/single-table/main.go b/feature/dynamodb/entitymanager/examples/single-table/main.go new file mode 100644 index 00000000000..50a31728bc1 --- /dev/null +++ b/feature/dynamodb/entitymanager/examples/single-table/main.go @@ -0,0 +1,333 @@ +package main + +import ( + "context" + "fmt" + "log" + "math/rand" + "sort" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +type FullTable struct { + PK string `dynamodbav:"PK,partition"` + SK string `dynamodbav:"SK,sort" dynamodbindex:"GSI,global,sort"` + Result any `dynamodbav:"result,coverter|json"` + LastUpdated time.Time `dynamodbav:"LastUpdated"` + GSI int64 `dynamodbav:"GSI" dynamodbindex:"GSI,global,partition;GSI_PK_Index,local,sort"` + TS int64 `dynamodbav:"ts"` +} + +type Race struct { + RaceID string `dynamodbav:"PK,partition"` // race-%d + ClassID string `dynamodbav:"SK,sort"` // class-%d +} + +type RaceResult struct { + RaceID string `dynamodbav:"PK,partition"` // race-%d + RacerID string `dynamodbav:"SK,sort"` // racer-%d + TimeMs int64 `dynamodbav:"time_ms"` // race time in milliseconds +} + +type Racer struct { + RaceID string `dynamodbav:"PK,partition"` //. race-%d + Name string `dynamodbav:"SK,sort"` // racer-%d +} + +func getTable[T any](client entitymanager.Client, tableName string) *entitymanager.Table[T] { + sch, err := entitymanager.NewSchema[T]() + if err != nil { + panic(err) + } + + sch = sch.WithTableName(aws.String(tableName)) + + tbl, err := entitymanager.NewTable(client, func(options *entitymanager.TableOptions[T]) { + options.Schema = sch + }) + if err != nil { + panic(err) + } + + return tbl +} + +func createTable(client entitymanager.Client, tableName string) context.CancelFunc { + // create the full table with gsi and lsi and all + log.Println("Creating full table") + tbl := getTable[FullTable](client, tableName) + + if exists, err := tbl.Exists(context.Background()); !exists || err != nil { + if err != nil { + panic(err) + } + + if err := tbl.CreateWithWait(context.Background(), time.Minute*2); err != nil { + panic(err) + } + log.Println("Created full table") + + return func() { + log.Println("Deleting full table") + if err := tbl.DeleteWithWait(context.Background(), time.Minute*2); err != nil { + panic(err) + } + log.Println("Deleted full table") + } + } + + return func() {} +} + +func generateData(client entitymanager.Client, tableName string, count int) { + generateRaces(client, tableName, count) + generateRacers(client, tableName, count) + generateRaceResults(client, tableName, count) +} + +func generateRaces(client entitymanager.Client, tableName string, count int) { + racesTbl := getTable[Race](client, tableName) + + for c := range count { + _, err := racesTbl.PutItem(context.Background(), &Race{ + RaceID: fmt.Sprintf("race-%d", c), + ClassID: fmt.Sprintf("class-%d", c), + }) + + if err != nil { + log.Printf("Error writing race") + } + } +} + +func generateRacers(client entitymanager.Client, tableName string, count int) { + racersTbl := getTable[Racer](client, tableName) + + for c := range count { + _, err := racersTbl.PutItem(context.Background(), &Racer{ + RaceID: fmt.Sprintf("race-%d", c), + Name: fmt.Sprintf("name-%d", c), + }) + + if err != nil { + log.Printf("Error writing racers") + } + } +} +func generateRaceResults(client entitymanager.Client, tableName string, count int) { + raceResultsTbl := getTable[RaceResult](client, tableName) + + rand.Seed(time.Now().UnixNano()) + + for c := range count { + for d := range count { + // Simulate a race time in milliseconds (e.g. 60s–80s) + timeMs := int64(60000 + rand.Intn(20000)) + + _, err := raceResultsTbl.PutItem(context.Background(), &RaceResult{ + RaceID: fmt.Sprintf("race-%d", c), + RacerID: fmt.Sprintf("racer-%d", d), + TimeMs: timeMs, + }) + + if err != nil { + log.Printf("Error writing race result") + } + } + } +} + +func main() { + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + panic(err) + } + + tableName := fmt.Sprintf("table_%s", time.Now().Format("2006_01_02_15_04_05")) + println(tableName) + + ddb := dynamodb.NewFromConfig(cfg) + count := 10 + + cancel := createTable(ddb, tableName) + defer cancel() + + generateData(ddb, tableName, count) + + racerTbl := getTable[Racer](ddb, tableName) + _ = racerTbl + raceTbl := getTable[Race](ddb, tableName) + _ = raceTbl + raceResultTbl := getTable[RaceResult](ddb, tableName) + + ctx := context.Background() + + // show leaderboard table - max 10 rows + showLeaderboard(ctx, raceResultTbl, 10) + + // show race results from random race + showRandomRaceResults(ctx, raceResultTbl, count) + + // show rankings + showRaceRankings(ctx, raceResultTbl, 10) +} + +func showLeaderboard(ctx context.Context, raceResultTbl *entitymanager.Table[RaceResult], maxRows int) { + // Only consider RaceResult items (SK starts with "racer-") + f := expression.Name("SK").BeginsWith("racer-") + expr, err := expression.NewBuilder().WithFilter(f).Build() + if err != nil { + log.Printf("error building expression for leaderboard scan: %v", err) + return + } + + // Track best (smallest) time per racer across all races + type racerBest struct { + RacerID string + TimeMs int64 + } + + bestByRacer := map[string]int64{} + for res := range raceResultTbl.Scan(ctx, expr) { + if res.Error() != nil { + log.Printf("Scan() error while building leaderboard: %v", res.Error()) + continue + } + + item := res.Item() + if item == nil { + continue + } + + t := item.TimeMs + current, ok := bestByRacer[item.RacerID] + if !ok || t < current { + bestByRacer[item.RacerID] = t + } + } + + entries := make([]racerBest, 0, len(bestByRacer)) + for id, t := range bestByRacer { + entries = append(entries, racerBest{RacerID: id, TimeMs: t}) + } + + sort.Slice(entries, func(i, j int) bool { + return entries[i].TimeMs < entries[j].TimeMs + }) + + fmt.Println("== Leaderboard (top racers by best time) ==") + for i, e := range entries { + if i >= maxRows { + break + } + fmt.Printf("%2d. %-10s %d ms\n", i+1, e.RacerID, e.TimeMs) + } +} + +func showRandomRaceResults(ctx context.Context, raceResultTbl *entitymanager.Table[RaceResult], races int) { + if races <= 0 { + return + } + + rand.Seed(time.Now().UnixNano()) + idx := rand.Intn(races) + raceID := fmt.Sprintf("race-%d", idx) + + // Query a single race, but only RaceResult items (SK starts with "racer-") + keyCond := expression.Key("PK").Equal(expression.Value(raceID)).And( + expression.Key("SK").BeginsWith("racer-"), + ) + expr, err := expression.NewBuilder().WithKeyCondition(keyCond).Build() + if err != nil { + log.Printf("error building expression for random race query: %v", err) + return + } + + // Collect all results so we can sort by time + results := make([]*RaceResult, 0) + for res := range raceResultTbl.Query(ctx, expr) { + if res.Error() != nil { + log.Printf("Query() error for %s: %v", raceID, res.Error()) + continue + } + + item := res.Item() + if item == nil { + continue + } + + results = append(results, item) + } + + sort.Slice(results, func(i, j int) bool { + return results[i].TimeMs < results[j].TimeMs + }) + + fmt.Printf("== Results for %s ==\n", raceID) + for i, r := range results { + fmt.Printf("%2d. %-10s %d ms\n", i+1, r.RacerID, r.TimeMs) + } +} + +func showRaceRankings(ctx context.Context, raceResultTbl *entitymanager.Table[RaceResult], maxRows int) { + // Only consider RaceResult items (SK starts with "racer-") + f := expression.Name("SK").BeginsWith("racer-") + expr, err := expression.NewBuilder().WithFilter(f).Build() + if err != nil { + log.Printf("error building expression for rankings scan: %v", err) + return + } + + // For each race, track the best (smallest) time and the pilot + type raceBest struct { + RaceID string + BestTime int64 + BestRacer string + } + + bestByRace := map[string]raceBest{} + for res := range raceResultTbl.Scan(ctx, expr) { + if res.Error() != nil { + log.Printf("Scan() error while building rankings: %v", res.Error()) + continue + } + + item := res.Item() + if item == nil { + continue + } + + t := item.TimeMs + current, ok := bestByRace[item.RaceID] + if !ok || t < current.BestTime { + bestByRace[item.RaceID] = raceBest{ + RaceID: item.RaceID, + BestTime: t, + BestRacer: item.RacerID, + } + } + } + + races := make([]raceBest, 0, len(bestByRace)) + for _, v := range bestByRace { + races = append(races, v) + } + + sort.Slice(races, func(i, j int) bool { + return races[i].BestTime < races[j].BestTime + }) + + fmt.Println("== Race rankings (by best pilot time) ==") + for i, e := range races { + if i >= maxRows { + break + } + fmt.Printf("%2d. %-10s %d ms (best: %s)\n", i+1, e.RaceID, e.BestTime, e.BestRacer) + } +} diff --git a/feature/dynamodb/entitymanager/extension.go b/feature/dynamodb/entitymanager/extension.go new file mode 100644 index 00000000000..f168c165b6c --- /dev/null +++ b/feature/dynamodb/entitymanager/extension.go @@ -0,0 +1,659 @@ +package entitymanager + +import ( + "context" + cryptorand "crypto/rand" + "fmt" + "reflect" + "slices" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + smythyrand "github.com/aws/smithy-go/rand" +) + +type ( + // CachedFieldsKey is the context key for storing CachedFields in a context.Context. + CachedFieldsKey struct{} + + // TableSchemaKey is the context key for storing a TableSchema in a context.Context. + TableSchemaKey struct{} +) + +// BeforeReader is implemented by types that want to run logic +// before a value of type T is read (e.g., fetched from storage or unmarshaled). +// The hook receives the context and a pointer to the value that will be read. +type BeforeReader[T any] interface { + BeforeRead(context.Context, *T) error +} + +// AfterReader is implemented by types that want to run logic +// after a value of type T is read (e.g., fetched from storage or unmarshaled). +// The hook receives the context and a pointer to the value that was read. +type AfterReader[T any] interface { + AfterRead(context.Context, *T) error +} + +// BeforeWriter is implemented by types that want to execute logic +// before a value of type T is written (e.g., persisted or marshaled). +// The hook receives the context and a pointer to the value that will be written. +type BeforeWriter[T any] interface { + BeforeWrite(context.Context, *T) error +} + +// AfterWriter is implemented by types that want to execute logic +// after a value of type T is written (e.g., persisted or marshaled). +// The hook receives the context and a pointer to the value that was written. +type AfterWriter[T any] interface { + AfterWrite(context.Context, *T) error +} + +// ConditionExpressionBuilder allows a type to inject custom logic for building +// a DynamoDB condition expression during a write operation (UpdateItem). +// The interface embeds BeforeWriter[T] for pre-write hooks, and provides +// BuildCondition to modify or set the ConditionBuilder used in the update. +// Typical use cases include optimistic locking, version checks, or enforcing +// business rules before an update is applied. +// +// Used only in UpdateItem() to construct the ConditionExpression for the request. +type ConditionExpressionBuilder[T any] interface { + BeforeWriter[T] + BuildCondition(context.Context, *T, **expression.ConditionBuilder) error +} + +// FilterExpressionBuilder allows a type to inject custom logic for building +// a DynamoDB condition expression during a write operation (UpdateItem). +// The interface embeds BeforeWriter[T] for pre-write hooks, and provides +// BuildFilter to modify or set the ConditionBuilder used as a filter. +// This is useful for advanced filtering scenarios, though in your codebase +// it is only used in UpdateItem(). +// +// Used only in UpdateItem() to construct the FilterExpression for the request. +type FilterExpressionBuilder[T any] interface { + BeforeWriter[T] + BuildFilter(context.Context, *T, **expression.ConditionBuilder) error +} + +// KeyConditionBuilder allows a type to inject custom logic for building +// a DynamoDB condition expression during a write operation (UpdateItem). +// The interface embeds BeforeWriter[T] for pre-write hooks, and provides +// BuildKeyCondition to modify or set the KeyConditionBuilder. +// This is useful for customizing how keys are matched in conditional updates. +// +// Used only in UpdateItem() to construct the KeyConditionExpression for the request. +type KeyConditionBuilder[T any] interface { + BeforeWriter[T] + BuildKeyCondition(context.Context, *T, **expression.KeyConditionBuilder) error +} + +// ProjectionExpressionBuilder allows a type to inject custom logic for building +// a DynamoDB condition expression during a write operation (UpdateItem). +// The interface embeds BeforeWriter[T] for pre-write hooks, and provides +// BuildProjection to modify or set the ProjectionBuilder. +// This is useful for controlling which attributes are returned after an update. +// +// Used only in UpdateItem() to construct the ProjectionExpression for the request. +type ProjectionExpressionBuilder[T any] interface { + BeforeWriter[T] + BuildProjection(context.Context, *T, **expression.ProjectionBuilder) error +} + +// UpdateExpressionBuilder allows a type to inject custom logic for building +// a DynamoDB update expression during an update operation (UpdateItem). +// The interface embeds BeforeWriter[T] for pre-write hooks, and provides +// BuildUpdate to modify or set the UpdateBuilder. +// This is useful for customizing how attributes are updated, supporting +// features like atomic counters, version increments, or custom field updates. +// +// Used only in UpdateItem() to construct the UpdateExpression for the request. +type UpdateExpressionBuilder[T any] interface { + BeforeWriter[T] + BuildUpdate(context.Context, *T, **expression.UpdateBuilder) error +} + +// AutogenerateExtension provides automatic population of fields marked as +// "autogenerated" in the schemext. It supports two main features: +// - Key generation: Assigns a UUID to fields tagged with `autogenerated:key`. +// - Timestamp generation: Assigns the current time to fields tagged with `autogenerated:timestamp`. +// +// The extension is intended to be used as a BeforeWriter and UpdateExpressionBuilder +// for types managed by the DynamoDB entity manager. It inspects the schema's +// CachedFields and updates fields as needed before write operations or when +// building update expressions. +// +// Supported field types for key generation: string, []byte. +// Supported field types for timestamp generation: string, []byte, int64, uint64, time.Time. +// +// Tag options: +// - `autogenerated:key`: Generates a UUID for the field. +// - Optionally, add "always" to force regeneration even if the field is non-zero. +// - `autogenerated:timestamp`: Sets the field to the current time. +// - Optionally, add "always" to force update even if the field is non-zero. +// +// Errors are returned if the tag is misconfigured, the field type is unsupported, +// or if UUID/time generation fails. +type AutogenerateExtension[T any] struct{} + +// BeforeWrite scans all CachedFields for the item and automatically populates +// fields marked as "autogenerated" before a write operation. For each field: +// - If tagged with `autogenerated:key`, assigns a UUID if the field is zero +// or if the "always" option is present. +// - If tagged with `autogenerated:timestamp`, assigns the current time if the +// field is zero or if the "always" option is present. +// +// Returns an error if tag options are missing, misconfigured, or if assignment fails. +func (ext *AutogenerateExtension[T]) BeforeWrite(ctx context.Context, item *T) error { + cachedFields := ctx.Value(CachedFieldsKey{}).(*CachedFields) + + for _, f := range cachedFields.All() { + if !f.AutoGenerated { + continue + } + + opts, ok := f.Tag.Option("autogenerated") + if !ok || len(opts) < 1 { + return fmt.Errorf("option autogenerated expects at least 1 option, e.g. autogenerated:key or autogenerated:timestamp") + } + + switch opts[0] { + case "key": + if err := ext.processKey(item, f, opts[1:]); err != nil { + return err + } + case "timestamp": + if err := ext.processTimestamp(item, f, opts[1:]); err != nil { + return err + } + default: + return fmt.Errorf(`option autogenerated can only process key and timestamp as first argument, "%s" given`, opts[0]) + } + } + + return nil +} + +// BuildUpdate scans all CachedFields for the item and, for each field marked +// as "autogenerated", adds an update statement to the UpdateBuilder: +// - For `autogenerated:key`, sets the field to its current value (if not a key field). +// - For `autogenerated:timestamp`, sets the field to its current value or nil if zero. +// +// Returns an error if tag options are missing, misconfigured, or if assignment fails. +func (ext *AutogenerateExtension[T]) BuildUpdate(ctx context.Context, item *T, ub **expression.UpdateBuilder) error { + cachedFields := ctx.Value(CachedFieldsKey{}).(*CachedFields) + + for _, f := range cachedFields.All() { + if !f.AutoGenerated { + continue + } + + opts, ok := f.Tag.Option("autogenerated") + if !ok || len(opts) < 1 { + return fmt.Errorf("option autogenerated expects at least 1 option, e.g. autogenerated:key or autogenerated:timestamp") + } + + switch opts[0] { + case "key": + if err := ext.buildKeyUpdate(item, f, ub); err != nil { + return err + } + case "timestamp": + if err := ext.buildTimestampUpdate(item, f, ub); err != nil { + return err + } + default: + return fmt.Errorf(`option autogenerated can only process key and timestamp as first argument, "%s" given`, opts[0]) + } + } + + return nil +} + +// processKey assigns a UUID to the specified field if it is zero or if the +// "always" option is present. Supports string and []byte fields. Uses the +// field's getter/setter if provided. +// +// Returns an error if the field type is unsupported or if UUID generation fails. +func (ext *AutogenerateExtension[T]) processKey(v *T, f Field, opts []string) error { + r := reflect.ValueOf(v) + var cv reflect.Value + + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + var err error + cv, err = r.Elem().FieldByIndexErr(f.Index) + if err != nil { //&& unwrap(s.options.ErrorOnMissingField) { + return err + } + } + + shouldUpdate := cv.IsZero() || slices.Contains(opts, "always") + if !shouldUpdate { + return nil + } + + s, err := smythyrand.NewUUID(cryptorand.Reader).GetUUID() + if err != nil { + return fmt.Errorf("error generating UUID: %v", err) + } + + if !cv.CanAddr() && f.Tag.Setter != "" { + cv = reflect.New(cv.Type()).Elem() + } + + switch cv.Kind() { + case reflect.String: + cv.SetString(s) + case reflect.Slice, reflect.Array: + if cv.Type().Elem().Kind() == reflect.Uint8 { + cv.SetBytes([]byte(s)) + } + default: + return fmt.Errorf("unable to assign autogenerated key to type %s, can only assign to string and []byte", cv.Type()) + } + + if f.Tag.Setter != "" { + r.MethodByName(f.Tag.Setter). + Call([]reflect.Value{ + cv, + }) + } + + return nil +} + +// processTimestamp assigns the current time to the specified field if it is +// zero or if the "always" option is present. Supports string, []byte, int64, +// uint64, and time.Time fields. Uses the field's getter/setter if provided. +// +// Returns an error if the field type is unsupported or if time assignment fails. +func (ext *AutogenerateExtension[T]) processTimestamp(v *T, f Field, opts []string) error { + r := reflect.ValueOf(v) + var cv reflect.Value + + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + var err error + cv, err = r.Elem().FieldByIndexErr(f.Index) + if err != nil { //&& unwrap(s.options.ErrorOnMissingField) { + return err + } + } + + shouldUpdate := cv.IsZero() || slices.Contains(opts, "always") + if !shouldUpdate { + return nil + } + + now := time.Now() + + if !cv.CanAddr() && f.Tag.Setter != "" { + cv = reflect.New(cv.Type()).Elem() + } + + switch cv.Kind() { + case reflect.String: + cv.SetString(now.Format(time.RFC3339)) + case reflect.Slice, reflect.Array: + if cv.Type().Elem().Kind() == reflect.Uint8 { + cv.SetBytes([]byte(now.Format(time.RFC3339))) + } + case reflect.Uint64, reflect.Int64: + n := reflect.ValueOf(now.UnixNano()).Convert(cv.Type()) + cv.Set(n) + default: + if _, ok := cv.Interface().(time.Time); !ok { + return fmt.Errorf("unable to assign autogenerated key to type %s, can only assign to string, []byte and time.Time", cv.Type()) + } + + cv.Set(reflect.ValueOf(now)) + } + + if f.Tag.Setter != "" { + r.MethodByName(f.Tag.Setter). + Call([]reflect.Value{ + cv, + }) + } + + return nil +} + +// buildKeyUpdate adds an update statement to the UpdateBuilder for a key field +// marked as "autogenerated". Only non-key fields are updated. Supports string +// and []byte fields. +// +// Returns an error if the field type is unsupported. +func (ext *AutogenerateExtension[T]) buildKeyUpdate(v *T, f Field, ub **expression.UpdateBuilder) error { + var update expression.UpdateBuilder + if ub != nil && *ub != nil { + update = **ub + } + + r := reflect.ValueOf(v) + var cv reflect.Value + + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + var err error + cv, err = r.Elem().FieldByIndexErr(f.Index) + if err != nil { //&& unwrap(s.options.ErrorOnMissingField) { + return err + } + } + + // pk and sk cannot be updated + if !cv.IsZero() && (f.Sort || f.Partition) { + return nil + } + + switch cv.Kind() { + case reflect.String: + update = update.Set(expression.Name(f.Name), expression.Value(cv.String())) + case reflect.Slice, reflect.Array: + if cv.Type().Elem().Kind() == reflect.Uint8 { + update = update.Set(expression.Name(f.Name), expression.Value(cv.Bytes())) + } + default: + return fmt.Errorf("unable to process update for autogenerated key to type %s, can only process to string and []byte", cv.Type()) + } + + return nil +} + +// buildTimestampUpdate adds an update statement to the UpdateBuilder for a +// timestamp field marked as "autogenerated". Sets the field to its current +// value or nil if zero. Supports string, []byte, int64, uint64, and time.Time fields. +// +// Returns an error if the field type is unsupported. +func (ext *AutogenerateExtension[T]) buildTimestampUpdate(v *T, f Field, ub **expression.UpdateBuilder) error { + var update expression.UpdateBuilder + if ub != nil && *ub != nil { + update = **ub + } + + r := reflect.ValueOf(v) + var cv reflect.Value + + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + var err error + cv, err = r.Elem().FieldByIndexErr(f.Index) + if err != nil { //&& unwrap(s.options.ErrorOnMissingField) { + return err + } + } + + // pk and sk cannot be updated + if !cv.IsZero() && (f.Sort || f.Partition) { + return nil + } + + if cv.IsZero() { + update = update.Set(expression.Name(f.Name), expression.Value(nil)) + *ub = &update + + return nil + } + + switch cv.Kind() { + case reflect.String, reflect.Uint64, reflect.Int64: + update = update.Set(expression.Name(f.Name), expression.Value(cv.Interface())) + *ub = &update + case reflect.Slice, reflect.Array: + if cv.Type().Elem().Kind() == reflect.Uint8 { + update = update.Set(expression.Name(f.Name), expression.Value(cv.Interface())) + *ub = &update + break + } + fallthrough + default: + if _, ok := cv.Interface().(time.Time); !ok { + return fmt.Errorf("unable to process update for autogenerated key to type %s, can only process to string, []byte and time.Time", cv.Type()) + } + + update = update.Set(expression.Name(f.Name), expression.Value(cv.Interface())) + *ub = &update + } + + return nil +} + +// VersionExtension provides optimistic locking and version control for items +// in DynamoDB tables. It automatically manages fields marked as "version" in +// the schema, ensuring that updates only succeed if the version matches or +// the attribute does not exist. +// +// Usage: +// - Implements BeforeWriter, ConditionExpressionBuilder, and UpdateExpressionBuilder. +// - Used in UpdateItem() to build conditional expressions and increment version fields. +// +// Supported field types: string, int64, uint64. +// - For string fields, the value is parsed as an integer and incremented. +// - For int64/uint64 fields, the value is incremented directly. +// +// Errors are returned if the field type is unsupported or if string parsing fails. +type VersionExtension[T any] struct{} + +// BeforeWrite is a no-op for VersionExtension. It does not modify the item +// before writing. Always returns nil. +func (ext *VersionExtension[T]) BeforeWrite(_ context.Context, _ *T) error { return nil } + +// BuildCondition constructs a conditional expression for versioned fields. +// For each field marked as "version": +// - The condition requires that the attribute does not exist OR its value +// matches the current version. +// - This ensures that updates only succeed if the item is new or the version +// matches, providing optimistic locking. +// +// The resulting ConditionBuilder is set in cb if any version fields are present. +// Returns nil unless reflection or field access fails. +func (ext *VersionExtension[T]) BuildCondition(ctx context.Context, item *T, cb **expression.ConditionBuilder) error { + cachedFields := ctx.Value(CachedFieldsKey{}).(*CachedFields) + + r := reflect.ValueOf(item) + var condition expression.ConditionBuilder + if cb != nil && *cb != nil { + condition = **cb + } + + for _, f := range cachedFields.All() { + if !f.Version { + continue + } + + var cv reflect.Value + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + cv = r.Elem().FieldByIndex(f.Index) + } + + if condition.IsSet() { + condition = condition.And( + expression.Or( + expression.AttributeNotExists(expression.Name(f.Name)), + expression.Equal( + expression.Name(f.Name), + expression.Value(cv.Interface()), + ), + ), + ) + } else { + condition = expression.Or( + expression.AttributeNotExists(expression.Name(f.Name)), + expression.Equal( + expression.Name(f.Name), + expression.Value(cv.Interface()), + ), + ) + } + } + + if condition.IsSet() { + *cb = &condition + } + + return nil +} + +// BuildUpdate constructs an update expression for versioned fields. +// For each field marked as "version": +// - If the field is a string and zero, it is initialized to "0". +// - The value is incremented by 1 (parsed as int for strings). +// - For int64/uint64 fields, the value is incremented directly. +// +// The resulting UpdateBuilder is set in ub for each version field. +// Returns an error if the field type is unsupported or if string parsing fails. +func (ext *VersionExtension[T]) BuildUpdate(ctx context.Context, item *T, ub **expression.UpdateBuilder) error { + cachedFields := ctx.Value(CachedFieldsKey{}).(*CachedFields) + + r := reflect.ValueOf(item) + var update expression.UpdateBuilder + if ub != nil && *ub != nil { + update = **ub + } + + for _, f := range cachedFields.All() { + if !f.Version { + continue + } + + var cv reflect.Value + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + cv = r.Elem().FieldByIndex(f.Index) + } + + switch cv.Kind() { + case reflect.String: + if cv.IsZero() { + cv.SetString("0") + } + i, err := strconv.ParseInt(cv.String(), 10, 64) + if err != nil { + return fmt.Errorf("unable to convert string value of version field %s to number", f.Name) + } + update = update.Set(expression.Name(f.Name), expression.Value(fmt.Sprintf("%d", i+1))) + *ub = &update + case reflect.Int64: + i := cv.Int() + update = update.Set(expression.Name(f.Name), expression.Value(i+1)) + *ub = &update + case reflect.Uint64: + i := cv.Uint() + update = update.Set(expression.Name(f.Name), expression.Value(i+1)) + *ub = &update + default: + return fmt.Errorf("unable to use %s as version field %s, can only use uint64, int64 and string", cv.Type(), f.Tag.Getter) + } + } + + return nil +} + +// AtomicCounterExtension provides automatic atomic increment logic for fields +// marked as "atomiccounter" in the schemext. When used as an UpdateExpressionBuilder, +// it generates DynamoDB update expressions that increment the field value atomically +// on each update. +// +// Usage: +// - Implements BeforeWriter and UpdateExpressionBuilder. +// - Used in UpdateItem() to build update expressions for atomic counter fields. +// +// Supported field types: int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64. +// +// Tag options (on the field): +// - "atomiccounter|start=0|delta=1" +// - start: Initial value if the attribute does not exist (default: 0). +// - delta: Amount to increment on each update (default: 1). +// +// Errors are returned if the field type is unsupported, tag options are misconfigured, +// or option values cannot be parsed as integers. +type AtomicCounterExtension[T any] struct{} + +// BeforeWrite is a no-op for AtomicCounterExtension. It does not modify the item +// before writing. Always returns nil. +func (ext *AtomicCounterExtension[T]) BeforeWrite(_ context.Context, _ *T) error { return nil } + +// BuildUpdate constructs an update expression for atomic counter fields. +// For each field marked as "atomiccounter": +// - The update expression uses DynamoDB's IfNotExists and Plus functions to +// atomically increment the field by the specified delta, initializing to +// start-delta if the attribute does not exist. +// - Tag options "start" and "delta" can be provided to customize the initial +// value and increment amount. +// +// The resulting UpdateBuilder is set in ub for each atomic counter field. +// Returns an error if the field type is unsupported, tag options are misconfigured, +// or option values cannot be parsed as integers. +func (ext *AtomicCounterExtension[T]) BuildUpdate(ctx context.Context, item *T, ub **expression.UpdateBuilder) error { + cachedFields := ctx.Value(CachedFieldsKey{}).(*CachedFields) + + var update expression.UpdateBuilder + if ub != nil && *ub != nil { + update = **ub + } + + for _, f := range cachedFields.All() { + if !f.AtomicCounter { + continue + } + + if f.Type.Kind() < reflect.Int || f.Type.Kind() > reflect.Uint64 { + return fmt.Errorf("atomic counter field %s has unsupported type %s", f.Name, f.Type.Kind()) + } + + dflt := 0 + delta := 1 + + if opts, ok := f.Option("atomiccounter"); ok { + for _, opt := range opts { + parts := strings.Split(opt, "=") + if len(parts) != 2 { + return fmt.Errorf(`field %s has the tag atomiccounter missonfigured, options must look like "atomiccounter|start=0|delta=1", "%s" given`, f.Name, opt) + } + + val, err := strconv.Atoi(parts[1]) + if err != nil { + return err + } + + switch parts[0] { + case "start": + dflt = val + case "delta": + delta = val + default: + return fmt.Errorf(`unknown options "%s" passed to "atomiccounter" on field %s`, parts[0], f.Name) + } + } + } + + update = update.Set( + expression.Name(f.Name), + expression.Plus( + expression.IfNotExists( + expression.Name(f.Name), + expression.Value(dflt-delta), + ), + expression.Value(delta), + ), + ) + *ub = &update + } + + return nil +} diff --git a/feature/dynamodb/entitymanager/extension_registry.go b/feature/dynamodb/entitymanager/extension_registry.go new file mode 100644 index 00000000000..16808eb0513 --- /dev/null +++ b/feature/dynamodb/entitymanager/extension_registry.go @@ -0,0 +1,91 @@ +package entitymanager + +// ExtensionRegistry manages a set of extension hooks for enhanced DynamoDB +// operations on a given type T. It allows registration of pre- and post-processing +// logic for read and write operations, enabling features such as automatic field +// population, versioning, atomic counters, and custom business logic. +// +// Extensions are grouped by operation type: +// - beforeReaders: Invoked before reading an item (e.g., GetItem). +// - afterReaders: Invoked after reading an item. +// - beforeWriters: Invoked before writing an item (e.g., PutItem, UpdateItem). +// - afterWriters: Invoked after writing an item. +// +// The registry supports method chaining for extension registration. +// DefaultExtensionRegistry provides a registry pre-populated with common extensions. +type ExtensionRegistry[T any] struct { + // GetItem + beforeReaders []BeforeReader[T] + afterReaders []AfterReader[T] + // PutItem | UpdateItem + beforeWriters []BeforeWriter[T] + afterWriters []AfterWriter[T] +} + +// AddBeforeReader registers a BeforeReader extension to be invoked before +// reading an item. Returns the registry for method chaining. +func (er *ExtensionRegistry[T]) AddBeforeReader(br BeforeReader[T]) *ExtensionRegistry[T] { + er.beforeReaders = append(er.beforeReaders, br) + + return er +} + +// AddAfterReader registers an AfterReader extension to be invoked after +// reading an item. Returns the registry for method chaining. +func (er *ExtensionRegistry[T]) AddAfterReader(ar AfterReader[T]) *ExtensionRegistry[T] { + er.afterReaders = append(er.afterReaders, ar) + + return er +} + +// AddBeforeWriter registers a BeforeWriter extension to be invoked before +// writing an item. Returns the registry for method chaining. +func (er *ExtensionRegistry[T]) AddBeforeWriter(bw BeforeWriter[T]) *ExtensionRegistry[T] { + er.beforeWriters = append(er.beforeWriters, bw) + + return er +} + +// AddAfterWriter registers an AfterWriter extension to be invoked after +// writing an item. Returns the registry for method chaining. +func (er *ExtensionRegistry[T]) AddAfterWriter(aw AfterWriter[T]) *ExtensionRegistry[T] { + er.afterWriters = append(er.afterWriters, aw) + + return er +} + +// Clone creates a new ExtensionRegistry containing copies of all registered +// extensions for type T. The returned registry has independent extension slices, +// so further modifications (adding/removing extensions) do not affect the original. +// +// Note: The extensions themselves are not deep-copied; only the slice references +// are duplicated. If extensions maintain internal state, that state will be shared. +// +// Returns a pointer to the new ExtensionRegistry. +func (er ExtensionRegistry[T]) Clone() *ExtensionRegistry[T] { + out := &ExtensionRegistry[T]{} + + out.beforeReaders = append(out.beforeReaders, er.beforeReaders...) + out.afterReaders = append(out.afterReaders, er.afterReaders...) + out.beforeWriters = append(out.beforeWriters, er.beforeWriters...) + out.afterWriters = append(out.afterWriters, er.afterWriters...) + + return out +} + +// DefaultExtensionRegistry returns a new ExtensionRegistry pre-populated with +// common beforeWriter extensions: AutogenerateExtension, AtomicCounterExtension, +// and VersionExtension. These provide automatic key/timestamp population, +// atomic counter updates, and optimistic versioning for write operations. +func DefaultExtensionRegistry[T any]() *ExtensionRegistry[T] { + out := &ExtensionRegistry[T]{} + + out.beforeWriters = append( + out.beforeWriters, + &AutogenerateExtension[T]{}, + &AtomicCounterExtension[T]{}, + &VersionExtension[T]{}, + ) + + return out +} diff --git a/feature/dynamodb/entitymanager/extension_registry_test.go b/feature/dynamodb/entitymanager/extension_registry_test.go new file mode 100644 index 00000000000..1258e0d5137 --- /dev/null +++ b/feature/dynamodb/entitymanager/extension_registry_test.go @@ -0,0 +1,60 @@ +package entitymanager + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +type dummyExtension struct{} + +func (*dummyExtension) IsExtension() {} + +func (*dummyExtension) BeforeRead(context.Context, *dummyExtension) error { return nil } +func (*dummyExtension) AfterRead(context.Context, *dummyExtension) error { return nil } +func (*dummyExtension) BeforeWrite(context.Context, *dummyExtension) error { return nil } +func (*dummyExtension) AfterWrite(context.Context, *dummyExtension) error { return nil } +func (*dummyExtension) BeforeQuery(context.Context, *dynamodb.QueryInput) error { + return nil +} +func (*dummyExtension) AfterQuery(context.Context, []dummyExtension) error { return nil } +func (*dummyExtension) BeforeScan(context.Context, *dynamodb.ScanInput) error { return nil } +func (*dummyExtension) AfterScan(context.Context, []dummyExtension) error { return nil } + +func TestExtensionRegistry(t *testing.T) { + er := &ExtensionRegistry[dummyExtension]{} + er.AddBeforeReader(&dummyExtension{}) + er.AddAfterReader(&dummyExtension{}) + er.AddBeforeWriter(&dummyExtension{}) + er.AddAfterWriter(&dummyExtension{}) + //er.AddBeforeScanner(&dummyExtension{}) + //er.AddAfterScanner(&dummyExtension{}) + //er.AddBeforeQuerier(&dummyExtension{}) + //er.AddAfterQuerier(&dummyExtension{}) + + if len(er.beforeReaders) != 1 { + t.Errorf("beforeReaders expected to be 1, got %d", len(er.beforeReaders)) + } + if len(er.afterReaders) != 1 { + t.Errorf("afterReaders expected to be 1, got %d", len(er.afterReaders)) + } + if len(er.beforeWriters) != 1 { + t.Errorf("beforeWriters expected to be 1, got %d", len(er.beforeWriters)) + } + if len(er.afterWriters) != 1 { + t.Errorf("afterWriters expected to be 1, got %d", len(er.afterWriters)) + } + //if len(er.beforeQueriers) != 1 { + // t.Errorf("beforeQueriers expected to be 1, got %d", len(er.beforeQueriers)) + //} + //if len(er.afterQueriers) != 1 { + // t.Errorf("afterQueriers expected to be 1, got %d", len(er.afterQueriers)) + //} + //if len(er.beforeScanners) != 1 { + // t.Errorf("beforeScanners expected to be 1, got %d", len(er.beforeScanners)) + //} + //if len(er.afterScanners) != 1 { + // t.Errorf("afterScanners expected to be 1, got %d", len(er.afterScanners)) + //} +} diff --git a/feature/dynamodb/entitymanager/extension_test.go b/feature/dynamodb/entitymanager/extension_test.go new file mode 100644 index 00000000000..1d36378cfdc --- /dev/null +++ b/feature/dynamodb/entitymanager/extension_test.go @@ -0,0 +1,357 @@ +package entitymanager + +import ( + "reflect" + "strconv" + "testing" + "time" +) + +func TestAutogenerateExtension(t *testing.T) { + +} + +type autogenerateTestStruct struct { + // props to work with + InputString string + InputByteArray []byte + InputNumber int64 + privateInputString string + privateInputByteArray []byte + privateInputNumber int64 + InputTime time.Time + privateInputTime time.Time + + // test stuff + field Field + error bool + validate func(autogenerateTestStruct) bool +} + +func (a *autogenerateTestStruct) GetPrivateInputString() string { + return a.privateInputString +} + +func (a *autogenerateTestStruct) SetPrivateInputString(s string) { + a.privateInputString = s +} + +func (a *autogenerateTestStruct) GetPrivateInputByteArray() []byte { + return a.privateInputByteArray +} + +func (a *autogenerateTestStruct) SetPrivateInputByteArray(s []byte) { + a.privateInputByteArray = s +} + +func (a *autogenerateTestStruct) GetPrivateInputNumber() int64 { + return a.privateInputNumber +} + +func (a *autogenerateTestStruct) SetPrivateInputNumber(s int64) { + a.privateInputNumber = s +} + +func (a *autogenerateTestStruct) GetPrivateInputTime() time.Time { + return a.privateInputTime +} + +func (a *autogenerateTestStruct) SetPrivateInputTime(s time.Time) { + a.privateInputTime = s +} + +func TestAutogenerateExtensionProcessKey(t *testing.T) { + cases := []autogenerateTestStruct{ + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + }, + Name: "inputString", + NameFromTag: false, + Index: []int{0}, + Type: reflect.TypeFor[string](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.InputString) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + }, + Name: "inputByteArray", + NameFromTag: false, + Index: []int{1}, + Type: reflect.TypeFor[[]byte](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.InputByteArray) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + }, + Name: "inputNumber", + NameFromTag: false, + Index: []int{2}, + Type: reflect.TypeFor[int64](), + }, + error: true, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + Getter: "GetPrivateInputString", + Setter: "SetPrivateInputString", + }, + Name: "inputString", + NameFromTag: false, + Index: []int{3}, + Type: reflect.TypeFor[string](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.privateInputString) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + Getter: "GetPrivateInputByteArray", + Setter: "SetPrivateInputByteArray", + }, + Name: "inputByteArray", + NameFromTag: false, + Index: []int{4}, + Type: reflect.TypeFor[[]byte](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.privateInputByteArray) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"key"}, + }, + Getter: "GetPrivateInputNumber", + Setter: "SetPrivateInputNumber", + }, + Name: "inputNumber", + NameFromTag: false, + Index: []int{5}, + Type: reflect.TypeFor[int64](), + }, + error: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ext := AutogenerateExtension[autogenerateTestStruct]{} + err := ext.processKey(&c, c.field, []string{}) + + if !c.error && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if c.error && err == nil { + t.Fatal("expected error") + } + + if c.validate != nil && !c.validate(c) { + t.Fatalf("failed to validate scenario") + } + }) + } +} + +func TestAutogenerateExtensionProcessTimestamp(t *testing.T) { + cases := []autogenerateTestStruct{ + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + }, + Name: "inputString", + NameFromTag: false, + Index: []int{0}, + Type: reflect.TypeFor[string](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.InputString) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + }, + Name: "inputByteArray", + NameFromTag: false, + Index: []int{1}, + Type: reflect.TypeFor[[]byte](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.InputByteArray) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + }, + Name: "inputNumber", + NameFromTag: false, + Index: []int{2}, + Type: reflect.TypeFor[int64](), + }, + validate: func(a autogenerateTestStruct) bool { + return a.InputNumber > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + Getter: "GetPrivateInputString", + Setter: "SetPrivateInputString", + }, + Name: "inputString", + NameFromTag: false, + Index: []int{3}, + Type: reflect.TypeFor[string](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.privateInputString) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + Getter: "GetPrivateInputByteArray", + Setter: "SetPrivateInputByteArray", + }, + Name: "inputByteArray", + NameFromTag: false, + Index: []int{4}, + Type: reflect.TypeFor[[]byte](), + }, + validate: func(a autogenerateTestStruct) bool { + return len(a.privateInputByteArray) > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + Getter: "GetPrivateInputNumber", + Setter: "SetPrivateInputNumber", + }, + Name: "inputNumber", + NameFromTag: false, + Index: []int{5}, + Type: reflect.TypeFor[int64](), + }, + validate: func(a autogenerateTestStruct) bool { + return a.privateInputNumber > 0 + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + }, + Name: "inputTime", + NameFromTag: false, + Index: []int{6}, + Type: reflect.TypeFor[int64](), + }, + validate: func(a autogenerateTestStruct) bool { + return !a.InputTime.IsZero() + }, + }, + { + field: Field{ + Tag: Tag{ + AutoGenerated: true, + Options: map[string][]string{ + "autogenerated": {"timestamp"}, + }, + Getter: "GetPrivateInputTime", + Setter: "SetPrivateInputTime", + }, + Name: "privateInputTime", + NameFromTag: false, + Index: []int{7}, + Type: reflect.TypeFor[int64](), + }, + validate: func(a autogenerateTestStruct) bool { + return !a.privateInputTime.IsZero() + }, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + ext := AutogenerateExtension[autogenerateTestStruct]{} + err := ext.processTimestamp(&c, c.field, []string{}) + + if !c.error && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if c.error && err == nil { + t.Fatal("expected error") + } + + if c.validate != nil && !c.validate(c) { + t.Fatalf("failed to validate scenario") + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/field.go b/feature/dynamodb/entitymanager/field.go new file mode 100644 index 00000000000..b3d80109e2e --- /dev/null +++ b/feature/dynamodb/entitymanager/field.go @@ -0,0 +1,285 @@ +package entitymanager + +import ( + "reflect" + "sort" +) + +// Field represents metadata about a struct field for schema mapping. +// It includes the field's name, type, index path, and tag information. +// Used internally for encoding/decoding Go structs to DynamoDB items. +type Field struct { + Tag // Parsed struct tag information + + Name string // Field name (possibly overridden by tag) + NameFromTag bool // True if the name was set by a struct tag + + Index []int // Index path for reflect.Value.FieldByIndex + Type reflect.Type // Field type +} + +func buildField(pIdx []int, i int, sf reflect.StructField, fieldTag Tag) Field { + f := Field{ + Name: sf.Name, + Type: sf.Type, + Tag: fieldTag, + } + if len(fieldTag.Name) != 0 { + f.NameFromTag = true + f.Name = fieldTag.Name + } + + f.Index = make([]int, len(pIdx)+1) + copy(f.Index, pIdx) + f.Index[len(pIdx)] = i + + return f +} + +type structFieldOptions struct { + // Support other custom struct Tag keys, such as `yaml`, `json`, or `toml`. + // Note that values provided with a custom TagKey must also be supported + // by the (un)marshalers in this package. + // + // Tag key `dynamodbav` will always be read, but if custom Tag key + // conflicts with `dynamodbav` the custom Tag key value will be used. + TagKey string +} + +// unionStructFields returns a list of CachedFields for the given type. Type info is cached +// to avoid repeated calls into the reflect package +func unionStructFields(t reflect.Type, opts structFieldOptions) *CachedFields { + key := fieldCacheKey{ + typ: t, + opts: opts, + } + + if cached, ok := fieldCache.Load(key); ok { + return cached + } + + f := enumFields(t, opts) + sort.Sort(fieldsByName(f)) + f = visibleFields(f) + + fs := &CachedFields{ + fields: f, + fieldsByName: make(map[string]int, len(f)), + } + for i, f := range fs.fields { + fs.fieldsByName[f.Name] = i + } + + cached, _ := fieldCache.LoadOrStore(key, fs) + return cached +} + +// enumFields will recursively iterate through a structure and its nested +// anonymous CachedFields. +// +// Based on the enoding/json struct Field enumeration of the Go Stdlib +// https://golang.org/src/encoding/json/encode.go typeField func. +func enumFields(t reflect.Type, opts structFieldOptions) []Field { + // Fields to explore + current := []Field{} + next := []Field{{Type: t}} + + // count of queued names + count := map[reflect.Type]int{} + nextCount := map[reflect.Type]int{} + + visited := map[reflect.Type]struct{}{} + fields := []Field{} + + for len(next) > 0 { + current, next = next, current[:0] + count, nextCount = nextCount, map[reflect.Type]int{} + + for _, f := range current { + if _, ok := visited[f.Type]; ok { + continue + } + visited[f.Type] = struct{}{} + + for i := 0; i < f.Type.NumField(); i++ { + sf := f.Type.Field(i) + + fieldTag := Tag{} + fieldTag.parseAVTag(sf.Tag) + // Because MarshalOptions.TagKey must be explicitly set. + if opts.TagKey != "" && opts.TagKey != defaultTagKey { + fieldTag.parseStructTag(opts.TagKey, sf.Tag) + } + + if sf.PkgPath != "" && !sf.Anonymous && fieldTag.Getter == "" && fieldTag.Setter == "" { + // Ignore unexported and non-anonymous CachedFields + // unexported but anonymous Field may still be used if + // the type has exported nested CachedFields + // or if they have a getter and setter + continue + } + + if fieldTag.Ignore { + continue + } + + ft := sf.Type + if ft.Name() == "" && ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + structField := buildField(f.Index, i, sf, fieldTag) + structField.Type = ft + + if !sf.Anonymous || fieldTag.Name != "" || ft.Kind() != reflect.Struct { + fields = append(fields, structField) + if count[f.Type] > 1 { + // If there were multiple instances, add a second, + // so that the annihilation code will see a duplicate. + // It only cares about the distinction between 1 or 2, + // so don't bother generating any more copies. + fields = append(fields, structField) + } + continue + } + + // Record new anon struct to explore next round + nextCount[ft]++ + if nextCount[ft] == 1 { + next = append(next, structField) + } + } + } + } + + return fields +} + +// visibleFields will return a slice of CachedFields which are visible based on +// Go's standard visiblity rules with the exception of ties being broken +// by depth and struct Tag naming. +// +// Based on the enoding/json Field filtering of the Go Stdlib +// https://golang.org/src/encoding/json/encode.go typeField func. +func visibleFields(fields []Field) []Field { + // Delete all CachedFields that are hidden by the Go rules for embedded CachedFields, + // except that CachedFields with JSON tags are promoted. + + // The CachedFields are sorted in primary order of name, secondary order + // of Field index length. Loop over names; for each name, delete + // hidden CachedFields by choosing the one dominant Field that survives. + out := fields[:0] + for advance, i := 0, 0; i < len(fields); i += advance { + // One iteration per name. + // Find the sequence of CachedFields with the name of this first Field. + fi := fields[i] + name := fi.Name + for advance = 1; i+advance < len(fields); advance++ { + fj := fields[i+advance] + if fj.Name != name { + break + } + } + if advance == 1 { // Only one Field with this name + out = append(out, fi) + continue + } + dominant, ok := dominantField(fields[i : i+advance]) + if ok { + out = append(out, dominant) + } + } + + fields = out + sort.Sort(fieldsByIndex(fields)) + + return fields +} + +// dominantField looks through the CachedFields, all of which are known to +// have the same name, to find the single Field that dominates the +// others using Go's embedding rules, modified by the presence of +// JSON tags. If there are multiple top-level CachedFields, the boolean +// will be false: This condition is an error in Go and we skip all +// the CachedFields. +// +// Based on the enoding/json Field filtering of the Go Stdlib +// https://golang.org/src/encoding/json/encode.go dominantField func. +func dominantField(fields []Field) (Field, bool) { + // The CachedFields are sorted in increasing index-length order. The winner + // must therefore be one with the shortest index length. Drop all + // longer entries, which is easy: just truncate the slice. + length := len(fields[0].Index) + tagged := -1 // Index of first tagged Field. + for i, f := range fields { + if len(f.Index) > length { + fields = fields[:i] + break + } + if f.NameFromTag { + if tagged >= 0 { + // Multiple tagged CachedFields at the same level: conflict. + // Return no Field. + return Field{}, false + } + tagged = i + } + } + if tagged >= 0 { + return fields[tagged], true + } + // All remaining CachedFields have the same length. If there's more than one, + // we have a conflict (two CachedFields named "X" at the same level) and we + // return no Field. + if len(fields) > 1 { + return Field{}, false + } + return fields[0], true +} + +// fieldsByName sorts Field by name, breaking ties with depth, +// then breaking ties with "name came from json Tag", then +// breaking ties with index sequence. +// +// Based on the enoding/json Field filtering of the Go Stdlib +// https://golang.org/src/encoding/json/encode.go fieldsByName type. +type fieldsByName []Field + +func (x fieldsByName) Len() int { return len(x) } + +func (x fieldsByName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x fieldsByName) Less(i, j int) bool { + if x[i].Name != x[j].Name { + return x[i].Name < x[j].Name + } + if len(x[i].Index) != len(x[j].Index) { + return len(x[i].Index) < len(x[j].Index) + } + if x[i].NameFromTag != x[j].NameFromTag { + return x[i].NameFromTag + } + return fieldsByIndex(x).Less(i, j) +} + +// fieldsByIndex sorts Field by index sequence. +// +// Based on the enoding/json Field filtering of the Go Stdlib +// https://golang.org/src/encoding/json/encode.go fieldsByIndex type. +type fieldsByIndex []Field + +func (x fieldsByIndex) Len() int { return len(x) } + +func (x fieldsByIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +func (x fieldsByIndex) Less(i, j int) bool { + for k, xik := range x[i].Index { + if k >= len(x[j].Index) { + return false + } + if xik != x[j].Index[k] { + return xik < x[j].Index[k] + } + } + return len(x[i].Index) < len(x[j].Index) +} diff --git a/feature/dynamodb/entitymanager/field_cache.go b/feature/dynamodb/entitymanager/field_cache.go new file mode 100644 index 00000000000..eef4e67f553 --- /dev/null +++ b/feature/dynamodb/entitymanager/field_cache.go @@ -0,0 +1,56 @@ +package entitymanager + +import ( + "reflect" + "strings" + "sync" +) + +var fieldCache = &fieldCacher{} + +type fieldCacheKey struct { + typ reflect.Type + opts structFieldOptions +} + +type fieldCacher struct { + cache sync.Map +} + +func (c *fieldCacher) Load(key fieldCacheKey) (*CachedFields, bool) { + if v, ok := c.cache.Load(key); ok { + return v.(*CachedFields), true + } + return nil, false +} + +func (c *fieldCacher) LoadOrStore(key fieldCacheKey, fs *CachedFields) (*CachedFields, bool) { + v, ok := c.cache.LoadOrStore(key, fs) + return v.(*CachedFields), ok +} + +// CachedFields holds a slice of Field metadata and a map for fast lookup by field name. +// Used to cache struct field information for efficient encoding/decoding. +type CachedFields struct { + fields []Field + fieldsByName map[string]int +} + +// All returns all cached Field metadata for the struct. +func (f *CachedFields) All() []Field { + return f.fields +} + +// FieldByName returns the Field metadata for the given name, case-insensitive. +// Returns the Field and true if found, or a zero Field and false otherwise. +func (f *CachedFields) FieldByName(name string) (Field, bool) { + if i, ok := f.fieldsByName[name]; ok { + return f.fields[i], ok + } + for _, f := range f.fields { + if strings.EqualFold(f.Name, name) { + return f, true + } + } + return Field{}, false +} diff --git a/feature/dynamodb/entitymanager/field_test.go b/feature/dynamodb/entitymanager/field_test.go new file mode 100644 index 00000000000..a1ef7d8606d --- /dev/null +++ b/feature/dynamodb/entitymanager/field_test.go @@ -0,0 +1,247 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "testing" +) + +type testUnionValues struct { + Name string + Value interface{} +} + +type unionSimple struct { + A int + B string + C []string +} + +type unionComplex struct { + unionSimple + A int +} + +type unionTagged struct { + A int `dynamodbav:"ddbav" json:"A" taga:"TagA" tagb:"TagB"` +} + +type unionTaggedComplex struct { + unionSimple + unionTagged + B string +} + +func TestUnionStructFields(t *testing.T) { + origFieldCache := fieldCache + defer func() { fieldCache = origFieldCache }() + + fieldCache = &fieldCacher{} + + var cases = map[string]struct { + in interface{} + opts structFieldOptions + expect []testUnionValues + }{ + "simple input": { + in: unionSimple{1, "2", []string{"abc"}}, + opts: structFieldOptions{TagKey: "json"}, + expect: []testUnionValues{ + {"A", 1}, + {"B", "2"}, + {"C", []string{"abc"}}, + }, + }, + "nested struct": { + in: unionComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + A: 2, + }, + opts: structFieldOptions{TagKey: "json"}, + expect: []testUnionValues{ + {"B", "2"}, + {"C", []string{"abc"}}, + {"A", 2}, + }, + }, + "with TagKey unset": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"ddbav", 3}, + {"B", "3"}, + }, + }, + "with TagKey json": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "json"}, + expect: []testUnionValues{ + {"C", []string{"abc"}}, + {"A", 3}, + {"B", "3"}, + }, + }, + "with TagKey taga": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "taga"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagA", 3}, + {"B", "3"}, + }, + }, + "with TagKey tagb": { + in: unionTaggedComplex{ + unionSimple: unionSimple{1, "2", []string{"abc"}}, + unionTagged: unionTagged{3}, + B: "3", + }, + opts: structFieldOptions{TagKey: "tagb"}, + expect: []testUnionValues{ + {"A", 1}, + {"C", []string{"abc"}}, + {"TagB", 3}, + {"B", "3"}, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + v := reflect.ValueOf(c.in) + + fields := unionStructFields(v.Type(), c.opts) + for i, f := range fields.All() { + expected := c.expect[i] + if e, a := expected.Name, f.Name; e != a { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + actual := v.FieldByIndex(f.Index).Interface() + if e, a := expected.Value, actual; !reflect.DeepEqual(e, a) { + t.Errorf("%d expect %v, got %v, %v", i, e, a, f) + } + } + }) + } +} + +func TestCachedFields(t *testing.T) { + type myStruct struct { + Dog int `tag1:"rabbit" tag2:"cow" tag3:"horse"` + CAT string + bird bool + } + + cases := map[string][]struct { + Name string + FieldName string + Found bool + }{ + "": { + {"Dog", "Dog", true}, + {"dog", "Dog", true}, + {"DOG", "Dog", true}, + {"Yorkie", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag1": { + {"rabbit", "rabbit", true}, + {"Rabbit", "rabbit", true}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag2": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "cow", true}, + {"Cow", "cow", true}, + {"horse", "", false}, + {"Horse", "", false}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + "tag3": { + {"rabbit", "", false}, + {"Rabbit", "", false}, + {"cow", "", false}, + {"Cow", "", false}, + {"horse", "horse", true}, + {"Horse", "horse", true}, + {"Dog", "", false}, + {"dog", "", false}, + {"DOG", "", false}, + {"Cat", "CAT", true}, + {"cat", "CAT", true}, + {"CAT", "CAT", true}, + {"tiger", "", false}, + {"bird", "", false}, + }, + } + + for tagKey, cs := range cases { + tagKey := tagKey + cs := cs + for _, c := range cs { + name := tagKey + if name == "" { + name = "none" + } + + c := c + t.Run(fmt.Sprintf("%s/%s", name, c.Name), func(t *testing.T) { + t.Parallel() + + fields := unionStructFields(reflect.TypeOf(myStruct{}), structFieldOptions{ + TagKey: tagKey, + }) + + const expectedNumFields = 2 + if numFields := len(fields.All()); numFields != expectedNumFields { + t.Errorf("expect %v CachedFields, got %d", expectedNumFields, numFields) + } + + f, found := fields.FieldByName(c.Name) + if found != c.Found { + t.Errorf("expect %v found, got %v", c.Found, found) + } + if found && f.Name != c.FieldName { + t.Errorf("expect %v Field name, got %s", c.FieldName, f.Name) + } + }) + } + } +} diff --git a/feature/dynamodb/entitymanager/go.mod b/feature/dynamodb/entitymanager/go.mod new file mode 100644 index 00000000000..58b64886c8e --- /dev/null +++ b/feature/dynamodb/entitymanager/go.mod @@ -0,0 +1,27 @@ +module github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager + +go 1.23 + +require ( + github.com/aws/aws-sdk-go-v2 v1.39.4 + github.com/aws/aws-sdk-go-v2/config v1.31.15 + github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 + github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 + github.com/aws/smithy-go v1.23.1 +) + +require ( + github.com/aws/aws-sdk-go-v2/credentials v1.18.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 // indirect +) diff --git a/feature/dynamodb/entitymanager/go.sum b/feature/dynamodb/entitymanager/go.sum new file mode 100644 index 00000000000..d3baef75c3f --- /dev/null +++ b/feature/dynamodb/entitymanager/go.sum @@ -0,0 +1,36 @@ +github.com/aws/aws-sdk-go-v2 v1.39.4 h1:qTsQKcdQPHnfGYBBs+Btl8QwxJeoWcOcPcixK90mRhg= +github.com/aws/aws-sdk-go-v2 v1.39.4/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/config v1.31.15 h1:gE3M4xuNXfC/9bG4hyowGm/35uQTi7bUKeYs5e/6uvU= +github.com/aws/aws-sdk-go-v2/config v1.31.15/go.mod h1:HvnvGJoE2I95KAIW8kkWVPJ4XhdrlvwJpV6pEzFQa8o= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19 h1:Jc1zzwkSY1QbkEcLujwqRTXOdvW8ppND3jRBb/VhBQc= +github.com/aws/aws-sdk-go-v2/credentials v1.18.19/go.mod h1:DIfQ9fAk5H0pGtnqfqkbSIzky82qYnGvh06ASQXXg6A= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19 h1:HmfJKLqMk6nVmyZrScNXVlnqfhIeiAcsJozT1Md8pBI= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.20.19/go.mod h1:BVQAm94IwIMmbNGwd7inlFczhZl75gIQWK7SejQPSRA= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19 h1:zgzPsusBAei9dD+PUb+26W7Ju6q/MM8+SrSCL7abJ54= +github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.8.19/go.mod h1:oQ3PiGmB6gdUDAO6y4XAOVkG/biM7qMxI/522eZLjMc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11 h1:X7X4YKb+c0rkI6d4uJ5tEMxXgCZ+jZ/D6mvkno8c8Uw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.11/go.mod h1:EqM6vPZQsZHYvC4Cai35UDg/f5NCEU+vp0WfbVqVcZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11 h1:7AANQZkF3ihM8fbdftpjhken0TP9sBzFbV/Ze/Y4HXA= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.11/go.mod h1:NTF4QCGkm6fzVwncpkFQqoquQyOolcyXfbpC98urj+c= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11 h1:ShdtWUZT37LCAA4Mw2kJAJtzaszfSHFb5n25sdcv4YE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.11/go.mod h1:7bUb2sSr2MZ3M/N+VyETLTQtInemHXb/Fl3s8CLzm0Y= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2 h1:v63QYOleHhBT1SctUsl4RXH+yjYuxQzpGxFRfjCmXBc= +github.com/aws/aws-sdk-go-v2/service/dynamodb v1.52.2/go.mod h1:OU+zHNgIjScCe8j2GAZ7uEWVMH3UupqAp2c2gpyckEE= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0 h1:ccmQULuINm6Yj9ynQY5+6rnDnGXCVQnWh5aqVDec+K8= +github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.32.0/go.mod h1:kPSrLRdnPrs1oEl7B5f6DInj2kpv3ePyh/Ow22zXlrw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11 h1:E+Q3COWEOkzzxo3kxG6zUskB3qsNMG/+UWbuREq5b9M= +github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.11.11/go.mod h1:p2NzdJjY5n+i+BAf9iw5jZRURdplXLX47IRB8LP2AgQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11 h1:GpMf3z2KJa4RnJ0ew3Hac+hRFYLZ9DDjfgXjuW+pB54= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.11/go.mod h1:6MZP3ZI4QQsgUCFTwMZA2V0sEriNQ8k2hmoHF3qjimQ= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8 h1:M5nimZmugcZUO9wG7iVtROxPhiqyZX6ejS1lxlDPbTU= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.8/go.mod h1:mbef/pgKhtKRwrigPPs7SSSKZgytzP8PQ6P6JAAdqyM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 h1:S5GuJZpYxE0lKeMHKn+BRTz6PTFpgThyJ+5mYfux7BM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3/go.mod h1:X4OF+BTd7HIb3L+tc4UlWHVrpgwZZIVENU15pRDVTI0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 h1:Ekml5vGg6sHSZLZJQJagefnVe6PmqC2oiRkBq4F7fU0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.9/go.mod h1:/e15V+o1zFHWdH3u7lpI3rVBcxszktIKuHKCY2/py+k= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= diff --git a/feature/dynamodb/entitymanager/index.go b/feature/dynamodb/entitymanager/index.go new file mode 100644 index 00000000000..009496eb044 --- /dev/null +++ b/feature/dynamodb/entitymanager/index.go @@ -0,0 +1,10 @@ +package entitymanager + +// Index represents metadata for a DynamoDB index, including its name and type (global/local, partition/sort). +type Index struct { + Name string // Index name + Global bool // True if the index is a global secondary index + Local bool // True if the index is a local secondary index + Partition bool // True if the index is a partition key + Sort bool // True if the index is a sort key +} diff --git a/feature/dynamodb/entitymanager/item_result.go b/feature/dynamodb/entitymanager/item_result.go new file mode 100644 index 00000000000..f1ff6f64b73 --- /dev/null +++ b/feature/dynamodb/entitymanager/item_result.go @@ -0,0 +1,23 @@ +package entitymanager + +// ItemResult represents the result of a DynamoDB operation that returns an item or an error. +// Used in iterators for batch, scan, and query operations to convey either a successfully decoded item or an error. +type ItemResult[T any] struct { + item T // The decoded item, if successful + table string + err error // The error encountered, if any +} + +func (it *ItemResult[T]) Table() string { + return it.table +} + +// Item returns the decoded item, or nil if an error occurred. +func (it *ItemResult[T]) Item() T { + return it.item +} + +// Error returns the error encountered during the operation, or nil if successful. +func (it *ItemResult[T]) Error() error { + return it.err +} diff --git a/feature/dynamodb/entitymanager/map.go b/feature/dynamodb/entitymanager/map.go new file mode 100644 index 00000000000..331f9e6674a --- /dev/null +++ b/feature/dynamodb/entitymanager/map.go @@ -0,0 +1,52 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// Map is a convenience type representing a DynamoDB item as a map of attribute names to AttributeValue. +// It is used for constructing, manipulating, and passing items to and from DynamoDB operations. +type Map map[string]types.AttributeValue + +func (m Map) String() string { + buff := strings.Builder{} + buff.WriteString("Map{") + for key, value := range m { + buff.WriteString(fmt.Sprintf("%s: %#+v, ", key, value)) + } + buff.WriteString("}") + + return buff.String() +} + +// With function takes a key string and a value of any kind and add it to the map as the corresponding AttributeValueMemberX type +// - []byte becomes types.AttributeValueMemberB +// - bool becomes types.AttributeValueMemberBOOL +// - [][]byte becomes types.AttributeValueMemberBS +// - []any becomes types.AttributeValueMemberL +// - map[any]any becomes types.AttributeValueMemberM +// - any type of int or float becomes types.AttributeValueMemberN +// - any type of int or float array ([5]type{...}) becomes types.AttributeValueMemberNS +// - nil becomes types.AttributeValueMemberNULL{Value: true} +// - string becomes types.AttributeValueMemberS +// - [3]string becomes types.AttributeValueMemberSS +// Note: [3] and [5] are not actual values we search for, they are just examples to illustrate go arrays vs go slices +func (m Map) With(key string, value any) Map { + v := reflect.ValueOf(value) + t := Tag{} + if v.Kind() == reflect.Array { + k := v.Type().Elem().Kind() + t.AsStrSet = k == reflect.String + t.AsNumSet = k >= reflect.Int && k <= reflect.Float64 && k != reflect.Uintptr + // t.AsBinSet is handled in encodeSlice() + } + av, _ := NewEncoder[any]().encode(v, t) + + m[key] = av + + return m +} diff --git a/feature/dynamodb/entitymanager/map_test.go b/feature/dynamodb/entitymanager/map_test.go new file mode 100644 index 00000000000..507cf9fb6a9 --- /dev/null +++ b/feature/dynamodb/entitymanager/map_test.go @@ -0,0 +1,106 @@ +package entitymanager + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestMap(t *testing.T) { + cases := map[string]struct { + input Map + expected map[string]types.AttributeValue + }{ + "string to map[string]AttributeValueMemberS": { + input: Map{}.With("k", "k"), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberS{Value: "k"}, + }, + }, + "string slice to map[string]AttributeValueMemberL": { + input: Map{}.With("k", []string{"k"}), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberS{ + Value: "k", + }, + }, + }, + }, + }, + "string array to map[string]AttributeValueMemberSS": { + input: Map{}.With("k", [1]string{"k"}), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberSS{Value: []string{"k"}}, + }, + }, + "int to map[string]AttributeValueMemberN": { + input: Map{}.With("k", 1), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberN{Value: "1"}, + }, + }, + "float to map[string]AttributeValueMemberN": { + input: Map{}.With("k", 1.23), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberN{Value: "1.23"}, + }, + }, + "int array to map[string]AttributeValueMemberNS": { + input: Map{}.With("k", [1]int{1}), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberNS{Value: []string{"1"}}, + }, + }, + "byte slice to map[string]AttributeValueMemberB": { + input: Map{}.With("k", []byte("k")), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberB{ + Value: []byte("k"), + }, + }, + }, + "byte array to map[string]AttributeValueMemberBS": { + input: Map{}.With("k", [][]byte{[]byte("k")}), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberBS{ + Value: [][]byte{[]byte("k")}, + }, + }, + }, + "nil to map[string]AttributeValueMemberNULL": { + input: Map{}.With("k", nil), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberNULL{Value: true}, + }, + }, + "map slice to map[string]AttributeValueMemberM": { + input: Map{}.With("k", map[string]string{"k": "v"}), + expected: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "k": &types.AttributeValueMemberS{ + Value: "v", + }, + }, + }, + }, + }, + "bool to map[string]AttributeValueMemberBOOL": { + input: Map{}.With("true", true).With("false", false), + expected: map[string]types.AttributeValue{ + "true": &types.AttributeValueMemberBOOL{Value: true}, + "false": &types.AttributeValueMemberBOOL{Value: false}, + }, + }, + } + + for k, c := range cases { + t.Run(k, func(t *testing.T) { + if diff := cmpDiff(c.input, Map(c.expected)); len(diff) > 0 { + t.Fatalf("unexpected diff: %s", diff) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/marshaler_test.go b/feature/dynamodb/entitymanager/marshaler_test.go new file mode 100644 index 00000000000..eb5a64d1ec7 --- /dev/null +++ b/feature/dynamodb/entitymanager/marshaler_test.go @@ -0,0 +1,769 @@ +package entitymanager + +import ( + "fmt" + "math" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type simpleMarshalStruct struct { + Byte []byte + String string + PtrString *string + Int int + Uint uint + Float32 float32 + Float64 float64 + Bool bool + Null *interface{} +} + +type complexMarshalStruct struct { + Simple []simpleMarshalStruct +} + +type myByteStruct struct { + Byte []byte +} + +type myByteSetStruct struct { + ByteSet [][]byte +} + +type marshallerTestInput struct { + input interface{} + expected interface{} + err error +} + +var trueValue = true +var falseValue = false + +var marshalerScalarInputs = map[string]marshallerTestInput{ + "nil": { + input: nil, + expected: &types.AttributeValueMemberNULL{Value: true}, + }, + "string": { + input: "some string", + expected: &types.AttributeValueMemberS{Value: "some string"}, + }, + "bool": { + input: true, + expected: &types.AttributeValueMemberBOOL{Value: true}, + }, + "bool false": { + input: false, + expected: &types.AttributeValueMemberBOOL{Value: false}, + }, + "float": { + input: 3.14, + expected: &types.AttributeValueMemberN{Value: "3.14"}, + }, + "max float32": { + input: math.MaxFloat32, + expected: &types.AttributeValueMemberN{Value: "340282346638528860000000000000000000000"}, + }, + "max float64": { + input: math.MaxFloat64, + expected: &types.AttributeValueMemberN{Value: "179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}, + }, + "integer": { + input: 12, + expected: &types.AttributeValueMemberN{Value: "12"}, + }, + "number integer": { + input: Number("12"), + expected: &types.AttributeValueMemberN{Value: "12"}, + }, + "zero values": { + input: simpleMarshalStruct{}, + expected: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Byte": &types.AttributeValueMemberNULL{Value: true}, + "Bool": &types.AttributeValueMemberBOOL{Value: false}, + "Float32": &types.AttributeValueMemberN{Value: "0"}, + "Float64": &types.AttributeValueMemberN{Value: "0"}, + "Int": &types.AttributeValueMemberN{Value: "0"}, + "Null": &types.AttributeValueMemberNULL{Value: true}, + "String": &types.AttributeValueMemberS{Value: ""}, + "PtrString": &types.AttributeValueMemberNULL{Value: true}, + "Uint": &types.AttributeValueMemberN{Value: "0"}, + }, + }, + }, +} + +var marshallerMapTestInputs = map[string]marshallerTestInput{ + // Scalar tests + "nil": { + input: nil, + expected: map[string]types.AttributeValue{}, + }, + "string": { + input: map[string]interface{}{"string": "some string"}, + expected: map[string]types.AttributeValue{"string": &types.AttributeValueMemberS{Value: "some string"}}, + }, + "bool": { + input: map[string]interface{}{"bool": true}, + expected: map[string]types.AttributeValue{"bool": &types.AttributeValueMemberBOOL{Value: true}}, + }, + "bool false": { + input: map[string]interface{}{"bool": false}, + expected: map[string]types.AttributeValue{"bool": &types.AttributeValueMemberBOOL{Value: false}}, + }, + "null": { + input: map[string]interface{}{"null": nil}, + expected: map[string]types.AttributeValue{"null": &types.AttributeValueMemberNULL{Value: true}}, + }, + "float": { + input: map[string]interface{}{"float": 3.14}, + expected: map[string]types.AttributeValue{"float": &types.AttributeValueMemberN{Value: "3.14"}}, + }, + "float32": { + input: map[string]interface{}{"float": math.MaxFloat32}, + expected: map[string]types.AttributeValue{"float": &types.AttributeValueMemberN{Value: "340282346638528860000000000000000000000"}}, + }, + "float64": { + input: map[string]interface{}{"float": math.MaxFloat64}, + expected: map[string]types.AttributeValue{"float": &types.AttributeValueMemberN{Value: "179769313486231570000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"}}, + }, + "decimal number": { + input: map[string]interface{}{"num": 12.}, + expected: map[string]types.AttributeValue{"num": &types.AttributeValueMemberN{Value: "12"}}, + }, + "byte": { + input: map[string]interface{}{"byte": []byte{48, 49}}, + expected: map[string]types.AttributeValue{"byte": &types.AttributeValueMemberB{Value: []byte{48, 49}}}, + }, + "nested blob": { + input: struct{ Byte []byte }{Byte: []byte{48, 49}}, + expected: map[string]types.AttributeValue{"Byte": &types.AttributeValueMemberB{Value: []byte{48, 49}}}, + }, + "map nested blob": { + input: map[string]interface{}{"byte_set": [][]byte{{48, 49}, {50, 51}}}, + expected: map[string]types.AttributeValue{"byte_set": &types.AttributeValueMemberBS{Value: [][]byte{{48, 49}, {50, 51}}}}, + }, + "bytes set": { + input: struct{ ByteSet [][]byte }{ByteSet: [][]byte{{48, 49}, {50, 51}}}, + expected: map[string]types.AttributeValue{"ByteSet": &types.AttributeValueMemberBS{Value: [][]byte{{48, 49}, {50, 51}}}}, + }, + "list": { + input: map[string]interface{}{"list": []interface{}{"a string", 12., 3.14, true, nil, false}}, + expected: map[string]types.AttributeValue{ + "list": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "a string"}, + &types.AttributeValueMemberN{Value: "12"}, + &types.AttributeValueMemberN{Value: "3.14"}, + &types.AttributeValueMemberBOOL{Value: true}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberBOOL{Value: false}, + }, + }, + }, + }, + "map": { + input: map[string]interface{}{"map": map[string]interface{}{"nestednum": 12.}}, + expected: map[string]types.AttributeValue{ + "map": &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "nestednum": &types.AttributeValueMemberN{Value: "12"}, + }, + }, + }, + }, + "struct": { + input: simpleMarshalStruct{}, + expected: map[string]types.AttributeValue{ + "Byte": &types.AttributeValueMemberNULL{Value: true}, + "Bool": &types.AttributeValueMemberBOOL{Value: false}, + "Float32": &types.AttributeValueMemberN{Value: "0"}, + "Float64": &types.AttributeValueMemberN{Value: "0"}, + "Int": &types.AttributeValueMemberN{Value: "0"}, + "Null": &types.AttributeValueMemberNULL{Value: true}, + "String": &types.AttributeValueMemberS{Value: ""}, + "PtrString": &types.AttributeValueMemberNULL{Value: true}, + "Uint": &types.AttributeValueMemberN{Value: "0"}, + }, + }, + "nested struct": { + input: complexMarshalStruct{}, + expected: map[string]types.AttributeValue{ + "Simple": &types.AttributeValueMemberNULL{Value: true}, + }, + }, + "nested nil slice": { + input: struct { + Simple []string `dynamodbav:"simple"` + }{}, + expected: map[string]types.AttributeValue{ + "simple": &types.AttributeValueMemberNULL{Value: true}, + }, + }, + "nested nil slice omit empty": { + input: struct { + Simple []string `dynamodbav:"simple,omitempty"` + }{}, + expected: map[string]types.AttributeValue{}, + }, + "nested ignored Field": { + input: struct { + Simple []string `dynamodbav:"-"` + }{}, + expected: map[string]types.AttributeValue{}, + }, + "complex struct members with zero": { + input: complexMarshalStruct{Simple: []simpleMarshalStruct{{Int: -2}, {Uint: 5}}}, + expected: map[string]types.AttributeValue{ + "Simple": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Byte": &types.AttributeValueMemberNULL{Value: true}, + "Bool": &types.AttributeValueMemberBOOL{Value: false}, + "Float32": &types.AttributeValueMemberN{Value: "0"}, + "Float64": &types.AttributeValueMemberN{Value: "0"}, + "Int": &types.AttributeValueMemberN{Value: "-2"}, + "Null": &types.AttributeValueMemberNULL{Value: true}, + "String": &types.AttributeValueMemberS{Value: ""}, + "PtrString": &types.AttributeValueMemberNULL{Value: true}, + "Uint": &types.AttributeValueMemberN{Value: "0"}, + }, + }, + &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Byte": &types.AttributeValueMemberNULL{Value: true}, + "Bool": &types.AttributeValueMemberBOOL{Value: false}, + "Float32": &types.AttributeValueMemberN{Value: "0"}, + "Float64": &types.AttributeValueMemberN{Value: "0"}, + "Int": &types.AttributeValueMemberN{Value: "0"}, + "Null": &types.AttributeValueMemberNULL{Value: true}, + "String": &types.AttributeValueMemberS{Value: ""}, + "PtrString": &types.AttributeValueMemberNULL{Value: true}, + "Uint": &types.AttributeValueMemberN{Value: "5"}, + }, + }, + }, + }, + }, + }, +} + +var marshallerListTestInputs = map[string]marshallerTestInput{ + "nil": { + input: nil, + expected: []types.AttributeValue{}, + }, + "empty interface": { + input: []interface{}{}, + expected: []types.AttributeValue{}, + }, + "empty struct": { + input: []simpleMarshalStruct{}, + expected: []types.AttributeValue{}, + }, + "various types": { + input: []interface{}{"a string", 12., 3.14, true, nil, false}, + expected: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "a string"}, + &types.AttributeValueMemberN{Value: "12"}, + &types.AttributeValueMemberN{Value: "3.14"}, + &types.AttributeValueMemberBOOL{Value: true}, + &types.AttributeValueMemberNULL{Value: true}, + &types.AttributeValueMemberBOOL{Value: false}, + }, + }, + "nested zero values": { + input: []simpleMarshalStruct{{}}, + expected: []types.AttributeValue{ + &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Byte": &types.AttributeValueMemberNULL{Value: true}, + "Bool": &types.AttributeValueMemberBOOL{Value: false}, + "Float32": &types.AttributeValueMemberN{Value: "0"}, + "Float64": &types.AttributeValueMemberN{Value: "0"}, + "Int": &types.AttributeValueMemberN{Value: "0"}, + "Null": &types.AttributeValueMemberNULL{Value: true}, + "String": &types.AttributeValueMemberS{Value: ""}, + "PtrString": &types.AttributeValueMemberNULL{Value: true}, + "Uint": &types.AttributeValueMemberN{Value: "0"}, + }, + }, + }, + }, +} + +func Test_New_Marshal(t *testing.T) { + for name, test := range marshalerScalarInputs { + t.Run(name, func(t *testing.T) { + actual, err := Marshal(test.input) + if test.err != nil { + if err == nil { + t.Errorf("Marshal with input %#v returned %#v, expected error `%s`", + test.input, actual, test.err) + } else if err.Error() != test.err.Error() { + t.Errorf("Marshal with input %#v returned error `%s`, expected error `%s`", + test.input, err, test.err) + } + } else { + if err != nil { + t.Errorf("Marshal with input %#v returned error `%s`", test.input, err) + } + compareObjects(t, test.expected, actual) + } + }) + } +} + +func Test_New_Unmarshal(t *testing.T) { + // Using the same inputs from Marshal, test the reverse mapping. + for name, test := range marshalerScalarInputs { + t.Run(name, func(t *testing.T) { + if test.input == nil { + t.Skip() + } + actual := reflect.New(reflect.TypeOf(test.input)).Interface() + if err := Unmarshal[any](test.expected.(types.AttributeValue), actual); err != nil { + t.Errorf("Unmarshal with input %#v returned error `%s`", test.expected, err) + } + compareObjects(t, test.input, reflect.ValueOf(actual).Elem().Interface()) + }) + } +} + +func Test_New_UnmarshalError(t *testing.T) { + // Test that we get an error using Unmarshal to convert to a nil value. + expected := &InvalidUnmarshalError{Type: reflect.TypeOf(nil)} + if err := Unmarshal[any](nil, nil); err == nil { + t.Errorf("Unmarshal with input %T returned no error, expected error `%v`", nil, expected) + } else if err.Error() != expected.Error() { + t.Errorf("Unmarshal with input %T returned error `%v`, expected error `%v`", nil, err, expected) + } + + // Test that we get an error using Unmarshal to convert to a non-pointer value. + var actual map[string]interface{} + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual)} + if err := Unmarshal[any](nil, actual); err == nil { + t.Errorf("Unmarshal with input %T returned no error, expected error `%v`", actual, expected) + } else if err.Error() != expected.Error() { + t.Errorf("Unmarshal with input %T returned error `%v`, expected error `%v`", actual, err, expected) + } + + // Test that we get an error using Unmarshal to convert to nil struct. + var actual2 *struct{ A int } + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual2)} + if err := Unmarshal[any](nil, actual2); err == nil { + t.Errorf("Unmarshal with input %T returned no error, expected error `%v`", actual2, expected) + } else if err.Error() != expected.Error() { + t.Errorf("Unmarshal with input %T returned error `%v`, expected error `%v`", actual2, err, expected) + } +} + +func Test_New_MarshalMap(t *testing.T) { + for name, test := range marshallerMapTestInputs { + t.Run(name, func(t *testing.T) { + actual, err := MarshalMap(test.input) + if test.err != nil { + if err == nil { + t.Errorf("MarshalMap with input %#v returned %#v, expected error `%s`", + test.input, actual, test.err) + } else if err.Error() != test.err.Error() { + t.Errorf("MarshalMap with input %#v returned error `%s`, expected error `%s`", + test.input, err, test.err) + } + } else { + if err != nil { + t.Errorf("MarshalMap with input %#v returned error `%s`", test.input, err) + } + compareObjects(t, test.expected, actual) + } + }) + } +} + +func Test_New_UnmarshalMap(t *testing.T) { + // Using the same inputs from MarshalMap, test the reverse mapping. + for name, test := range marshallerMapTestInputs { + t.Run(name, func(t *testing.T) { + if test.input == nil { + t.Skip() + } + actual := reflect.New(reflect.TypeOf(test.input)).Interface() + if err := UnmarshalMap(test.expected.(map[string]types.AttributeValue), actual); err != nil { + t.Errorf("Unmarshal with input %#v returned error `%s`", test.expected, err) + } + compareObjects(t, test.input, reflect.ValueOf(actual).Elem().Interface()) + }) + } +} + +func Test_New_UnmarshalMapError(t *testing.T) { + // Test that we get an error using UnmarshalMap to convert to a nil value. + expected := &InvalidUnmarshalError{Type: reflect.TypeOf(nil)} + if err := UnmarshalMap[any](nil, nil); err == nil { + t.Errorf("UnmarshalMap with input %T returned no error, expected error `%v`", nil, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalMap with input %T returned error `%v`, expected error `%v`", nil, err, expected) + } + + // Test that we get an error using UnmarshalMap to convert to a non-pointer value. + var actual map[string]interface{} + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual)} + if err := UnmarshalMap(nil, actual); err == nil { + t.Errorf("UnmarshalMap with input %T returned no error, expected error `%v`", actual, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalMap with input %T returned error `%v`, expected error `%v`", actual, err, expected) + } + + // Test that we get an error using UnmarshalMap to convert to nil struct. + var actual2 *struct{ A int } + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual2)} + if err := UnmarshalMap(nil, actual2); err == nil { + t.Errorf("UnmarshalMap with input %T returned no error, expected error `%v`", actual2, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalMap with input %T returned error `%v`, expected error `%v`", actual2, err, expected) + } +} + +func Test_New_MarshalList(t *testing.T) { + for name, c := range marshallerListTestInputs { + t.Run(name, func(t *testing.T) { + actual, err := MarshalList(c.input) + if c.err != nil { + if err == nil { + t.Fatalf("marshalList with input %#v returned %#v, expected error `%s`", + c.input, actual, c.err) + } else if err.Error() != c.err.Error() { + t.Fatalf("marshalList with input %#v returned error `%s`, expected error `%s`", + c.input, err, c.err) + } + return + } + if err != nil { + t.Fatalf("MarshalList with input %#v returned error `%s`", c.input, err) + } + + compareObjects(t, c.expected, actual) + + }) + } +} + +func Test_New_UnmarshalList(t *testing.T) { + // Using the same inputs from MarshalList, test the reverse mapping. + for name, c := range marshallerListTestInputs { + t.Run(name, func(t *testing.T) { + if c.input == nil { + t.Skip() + } + + iv := reflect.ValueOf(c.input) + + actual := reflect.New(iv.Type()) + if iv.Kind() == reflect.Slice { + actual.Elem().Set(reflect.MakeSlice(iv.Type(), iv.Len(), iv.Cap())) + } + + if err := UnmarshalList(c.expected.([]types.AttributeValue), actual.Interface()); err != nil { + t.Errorf("unmarshal with input %#v returned error `%s`", c.expected, err) + } + compareObjects(t, c.input, actual.Elem().Interface()) + }) + } +} + +func Test_New_UnmarshalListError(t *testing.T) { + // Test that we get an error using UnmarshalList to convert to a nil value. + expected := &InvalidUnmarshalError{Type: reflect.TypeOf(nil)} + if err := UnmarshalList[any](nil, nil); err == nil { + t.Errorf("UnmarshalList with input %T returned no error, expected error `%v`", nil, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalList with input %T returned error `%v`, expected error `%v`", nil, err, expected) + } + + // Test that we get an error using UnmarshalList to convert to a non-pointer value. + var actual map[string]interface{} + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual)} + if err := UnmarshalList(nil, actual); err == nil { + t.Errorf("UnmarshalList with input %T returned no error, expected error `%v`", actual, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalList with input %T returned error `%v`, expected error `%v`", actual, err, expected) + } + + // Test that we get an error using UnmarshalList to convert to nil struct. + var actual2 *struct{ A int } + expected = &InvalidUnmarshalError{Type: reflect.TypeOf(actual2)} + if err := UnmarshalList(nil, actual2); err == nil { + t.Errorf("UnmarshalList with input %T returned no error, expected error `%v`", actual2, expected) + } else if err.Error() != expected.Error() { + t.Errorf("UnmarshalList with input %T returned error `%v`, expected error `%v`", actual2, err, expected) + } +} + +func compareObjects(t *testing.T, expected interface{}, actual interface{}) { + t.Helper() + if !reflect.DeepEqual(expected, actual) { + ev := reflect.ValueOf(expected) + av := reflect.ValueOf(actual) + if diff := cmpDiff(expected, actual); len(diff) != 0 { + t.Errorf("expect kind(%s, %T) match input kind(%s, %T)\n%s", + ev.Kind(), ev.Interface(), av.Kind(), av.Interface(), diff) + } + } +} + +func BenchmarkMarshalOneMember(b *testing.B) { + fieldCache = &fieldCacher{} + + simple := simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + Null: nil, + } + type MyCompositeStruct struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := Marshal(MyCompositeStruct{ + A: simple, + }); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func BenchmarkList20Ints(b *testing.B) { + input := []int{} + for i := 0; i < 20; i++ { + input = append(input, i) + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := Marshal(input) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkStruct10Fields(b *testing.B) { + + type struct10Fields struct { + Field1 int + Field2 string + Field3 int + Field4 string + Field5 string + Field6 string + Field7 int + Field8 string + Field9 int + Field10 int + } + + input := struct10Fields{ + Field1: 10, + Field2: "ASD", + Field3: 70, + Field4: "qqqqq", + Field5: "AAA", + Field6: "bbb", + Field7: 63, + Field8: "aa", + Field9: 10, + Field10: 63, + } + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, err := Marshal(input) + if err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkMarshalTwoMembers(b *testing.B) { + fieldCache = &fieldCacher{} + + simple := simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + Null: nil, + } + + type MyCompositeStruct struct { + A simpleMarshalStruct `dynamodbav:"a"` + B simpleMarshalStruct `dynamodbav:"b"` + } + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := Marshal(MyCompositeStruct{ + A: simple, + B: simple, + }); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func BenchmarkUnmarshalOneMember(b *testing.B) { + fieldCache = &fieldCacher{} + + myStructAVMap, _ := Marshal(simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + Null: nil, + }) + + type MyCompositeStructOne struct { + A simpleMarshalStruct `dynamodbav:"a"` + } + var out MyCompositeStructOne + avMap := map[string]types.AttributeValue{ + "a": myStructAVMap, + } + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := Unmarshal[any](&types.AttributeValueMemberM{Value: avMap}, &out); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func BenchmarkUnmarshalTwoMembers(b *testing.B) { + fieldCache = &fieldCacher{} + + myStructAVMap, _ := Marshal(simpleMarshalStruct{ + String: "abc", + Int: 123, + Uint: 123, + Float32: 123.321, + Float64: 123.321, + Bool: true, + Null: nil, + }) + + type MyCompositeStructTwo struct { + A simpleMarshalStruct `dynamodbav:"a"` + B simpleMarshalStruct `dynamodbav:"b"` + } + var out MyCompositeStructTwo + avMap := map[string]types.AttributeValue{ + "a": myStructAVMap, + "b": myStructAVMap, + } + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if err := Unmarshal[any](&types.AttributeValueMemberM{Value: avMap}, &out); err != nil { + b.Error("unexpected error:", err) + } + } + }) +} + +func Test_Encode_YAML_TagKey(t *testing.T) { + type Embedded struct { + String string `yaml:"string"` + } + + input := struct { + String string `yaml:"string"` + EmptyString string `yaml:"empty"` + OmitString string `yaml:"omitted,omitempty"` + Ignored string `yaml:"-"` + Byte []byte `yaml:"byte"` + Float32 float32 `yaml:"float32"` + Float64 float64 `yaml:"float64"` + Int int `yaml:"int"` + Uint uint `yaml:"uint"` + Slice []string `yaml:"slice"` + Map map[string]int `yaml:"map"` + NoTag string + Embedded `yaml:"embedded"` + }{ + String: "String", + Ignored: "Ignored", + Slice: []string{"one", "two"}, + Map: map[string]int{ + "one": 1, + "two": 2, + }, + NoTag: "NoTag", + Embedded: Embedded{ + String: "String", + }, + } + + expected := &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "string": &types.AttributeValueMemberS{Value: "String"}, + "empty": &types.AttributeValueMemberS{Value: ""}, + "byte": &types.AttributeValueMemberNULL{Value: true}, + "float32": &types.AttributeValueMemberN{Value: "0"}, + "float64": &types.AttributeValueMemberN{Value: "0"}, + "int": &types.AttributeValueMemberN{Value: "0"}, + "uint": &types.AttributeValueMemberN{Value: "0"}, + "slice": &types.AttributeValueMemberL{ + Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "one"}, + &types.AttributeValueMemberS{Value: "two"}, + }, + }, + "map": &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "one": &types.AttributeValueMemberN{Value: "1"}, + "two": &types.AttributeValueMemberN{Value: "2"}, + }, + }, + "NoTag": &types.AttributeValueMemberS{Value: "NoTag"}, + "embedded": &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "string": &types.AttributeValueMemberS{Value: "String"}, + }, + }, + }, + } + + enc := NewEncoder[any](func(o *EncoderOptions) { + o.TagKey = "yaml" + }) + + actual, err := enc.Encode(input) + if err != nil { + t.Errorf("Encode with input %#v returned error `%s`, expected nil", input, err) + } + + compareObjects(t, expected, actual) +} + +func cmpDiff(e, a interface{}) string { + if !reflect.DeepEqual(e, a) { + return fmt.Sprintf("%#+v != %#+v", e, a) + } + return "" +} diff --git a/feature/dynamodb/entitymanager/mock_client_test.go b/feature/dynamodb/entitymanager/mock_client_test.go new file mode 100644 index 00000000000..4d0b19f4711 --- /dev/null +++ b/feature/dynamodb/entitymanager/mock_client_test.go @@ -0,0 +1,691 @@ +package entitymanager + +import ( + "context" + "errors" + "fmt" + "maps" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +var _ Client = (*mockClient)(nil) + +type mockClient struct { + TableDescriptions map[string]types.TableDescription + Items map[string][]map[string]types.AttributeValue + + SetupFns []mockClientSetupFn + Expects []expectFn + + CreateTableCalls []ddbCall[dynamodb.CreateTableInput, dynamodb.CreateTableOutput] + DescribeTableCalls []ddbCall[dynamodb.DescribeTableInput, dynamodb.DescribeTableOutput] + DeleteTableCalls []ddbCall[dynamodb.DeleteTableInput, dynamodb.DeleteTableOutput] + GetItemCalls []ddbCall[dynamodb.GetItemInput, dynamodb.GetItemOutput] + PutItemCalls []ddbCall[dynamodb.PutItemInput, dynamodb.PutItemOutput] + DeleteItemCalls []ddbCall[dynamodb.DeleteItemInput, dynamodb.DeleteItemOutput] + UpdateItemCalls []ddbCall[dynamodb.UpdateItemInput, dynamodb.UpdateItemOutput] + BatchGetItemCalls []ddbCall[dynamodb.BatchGetItemInput, dynamodb.BatchGetItemOutput] + BatchWriteItemCalls []ddbCall[dynamodb.BatchWriteItemInput, dynamodb.BatchWriteItemOutput] + ScanCalls []ddbCall[dynamodb.ScanInput, dynamodb.ScanOutput] + QueryCalls []ddbCall[dynamodb.QueryInput, dynamodb.QueryOutput] +} + +func newMockClient(fns ...mockClientSetupFn) *mockClient { + out := &mockClient{} + + for _, fn := range fns { + fn(out) + } + + return out +} + +type ddbCall[I, O any] func(Client, context.Context, *I, ...func(*dynamodb.Options)) (*O, error) + +type ddbCallAsert[C, I, O any] func(*C, context.Context, *I, ...func(*dynamodb.Options)) (*O, error) + +var _ ddbCallAsert[dynamodb.Client, dynamodb.CreateTableInput, dynamodb.CreateTableOutput] = (*dynamodb.Client).CreateTable +var _ ddbCallAsert[dynamodb.Client, dynamodb.DescribeTableInput, dynamodb.DescribeTableOutput] = (*dynamodb.Client).DescribeTable +var _ ddbCallAsert[dynamodb.Client, dynamodb.DeleteTableInput, dynamodb.DeleteTableOutput] = (*dynamodb.Client).DeleteTable +var _ ddbCallAsert[dynamodb.Client, dynamodb.GetItemInput, dynamodb.GetItemOutput] = (*dynamodb.Client).GetItem +var _ ddbCallAsert[dynamodb.Client, dynamodb.PutItemInput, dynamodb.PutItemOutput] = (*dynamodb.Client).PutItem +var _ ddbCallAsert[dynamodb.Client, dynamodb.UpdateItemInput, dynamodb.UpdateItemOutput] = (*dynamodb.Client).UpdateItem +var _ ddbCallAsert[dynamodb.Client, dynamodb.DeleteItemInput, dynamodb.DeleteItemOutput] = (*dynamodb.Client).DeleteItem +var _ ddbCallAsert[dynamodb.Client, dynamodb.QueryInput, dynamodb.QueryOutput] = (*dynamodb.Client).Query +var _ ddbCallAsert[dynamodb.Client, dynamodb.ScanInput, dynamodb.ScanOutput] = (*dynamodb.Client).Scan + +func doDdbCall[I, O any]( + ctx context.Context, + client *mockClient, + callStack *[]ddbCall[I, O], + input *I, + optFns ...func(*dynamodb.Options), +) (*O, error) { + if callStack == nil || len(*callStack) == 0 { + return nil, fmt.Errorf(`unexpected call for %T`, callStack) + } + + call := (*callStack)[0] + *callStack = (*callStack)[1:] + + return call(client, ctx, input, optFns...) +} + +func (c *mockClient) CreateTable(ctx context.Context, input *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) { + return doDdbCall(ctx, c, &c.CreateTableCalls, input, optFns...) +} + +func (c *mockClient) DescribeTable(ctx context.Context, input *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) { + return doDdbCall(ctx, c, &c.DescribeTableCalls, input, optFns...) +} + +func (c *mockClient) DeleteTable(ctx context.Context, input *dynamodb.DeleteTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteTableOutput, error) { + return doDdbCall(ctx, c, &c.DeleteTableCalls, input, optFns...) +} + +func (c *mockClient) GetItem(ctx context.Context, input *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) { + return doDdbCall(ctx, c, &c.GetItemCalls, input, optFns...) +} + +func (c *mockClient) PutItem(ctx context.Context, input *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) { + return doDdbCall(ctx, c, &c.PutItemCalls, input, optFns...) +} + +func (c *mockClient) DeleteItem(ctx context.Context, input *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) { + return doDdbCall(ctx, c, &c.DeleteItemCalls, input, optFns...) +} + +func (c *mockClient) UpdateItem(ctx context.Context, input *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) { + return doDdbCall(ctx, c, &c.UpdateItemCalls, input, optFns...) +} + +func (c *mockClient) BatchGetItem(ctx context.Context, input *dynamodb.BatchGetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchGetItemOutput, error) { + return doDdbCall(ctx, c, &c.BatchGetItemCalls, input, optFns...) +} + +func (c *mockClient) BatchWriteItem(ctx context.Context, input *dynamodb.BatchWriteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchWriteItemOutput, error) { + return doDdbCall(ctx, c, &c.BatchWriteItemCalls, input, optFns...) +} + +func (c *mockClient) Scan(ctx context.Context, input *dynamodb.ScanInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) { + return doDdbCall(ctx, c, &c.ScanCalls, input, optFns...) +} + +func (c *mockClient) Query(ctx context.Context, input *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) { + return doDdbCall(ctx, c, &c.QueryCalls, input, optFns...) +} + +func (c *mockClient) RunExpectations(t *testing.T) { + for _, fn := range c.Expects { + if err := fn(t, c); err != nil { + t.Errorf("expectation failed: %v", err) + } + } +} + +type mockClientSetupFn func(*mockClient) + +func withDefaultCreateTableCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.CreateTableCalls = append(m.CreateTableCalls, defaultCreateTableCall(m, err)) + } +} + +func defaultCreateTableCall(client *mockClient, err error) ddbCall[dynamodb.CreateTableInput, dynamodb.CreateTableOutput] { + return func(_ Client, _ context.Context, input *dynamodb.CreateTableInput, _ ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) { + if err != nil { + return nil, err + } + + if client.TableDescriptions == nil { + client.TableDescriptions = make(map[string]types.TableDescription) + } + + tableName := aws.ToString(input.TableName) + if desc, found := client.TableDescriptions[tableName]; found { + return &dynamodb.CreateTableOutput{ + TableDescription: &desc, + }, nil + } + + desc := types.TableDescription{ + ArchivalSummary: nil, + AttributeDefinitions: input.AttributeDefinitions, + BillingModeSummary: func() *types.BillingModeSummary { + switch input.BillingMode { + case types.BillingModePayPerRequest: + return &types.BillingModeSummary{ + BillingMode: types.BillingModePayPerRequest, + LastUpdateToPayPerRequestDateTime: aws.Time(time.Now()), + } + case types.BillingModeProvisioned: + return &types.BillingModeSummary{ + BillingMode: types.BillingModeProvisioned, + LastUpdateToPayPerRequestDateTime: aws.Time(time.Now()), + } + default: + return nil + } + }(), + CreationDateTime: aws.Time(time.Now()), + DeletionProtectionEnabled: input.DeletionProtectionEnabled, + GlobalSecondaryIndexes: func() []types.GlobalSecondaryIndexDescription { + var out []types.GlobalSecondaryIndexDescription + for _, g := range input.GlobalSecondaryIndexes { + out = append(out, types.GlobalSecondaryIndexDescription{ + Backfilling: aws.Bool(false), + IndexArn: aws.String(fmt.Sprintf("arn:aws:dynamodb:eu-west-1:123456789012:table/%s/index/%s", tableName, *g.IndexName)), + IndexName: g.IndexName, + IndexSizeBytes: aws.Int64(0), + IndexStatus: types.IndexStatusActive, + ItemCount: aws.Int64(0), + KeySchema: g.KeySchema, + OnDemandThroughput: g.OnDemandThroughput, + Projection: g.Projection, + ProvisionedThroughput: func() *types.ProvisionedThroughputDescription { + if g.ProvisionedThroughput == nil { + return nil + } + + return &types.ProvisionedThroughputDescription{ + LastDecreaseDateTime: aws.Time(time.Now()), + LastIncreaseDateTime: aws.Time(time.Now()), + NumberOfDecreasesToday: aws.Int64(0), + ReadCapacityUnits: g.ProvisionedThroughput.ReadCapacityUnits, + WriteCapacityUnits: g.ProvisionedThroughput.WriteCapacityUnits, + } + }(), + WarmThroughput: func() *types.GlobalSecondaryIndexWarmThroughputDescription { + if g.WarmThroughput == nil { + return nil + } + + return &types.GlobalSecondaryIndexWarmThroughputDescription{ + ReadUnitsPerSecond: g.WarmThroughput.ReadUnitsPerSecond, + WriteUnitsPerSecond: g.WarmThroughput.WriteUnitsPerSecond, + Status: types.IndexStatusActive, + } + }(), + }) + } + return out + }(), + GlobalTableVersion: nil, + GlobalTableWitnesses: nil, + ItemCount: aws.Int64(0), + KeySchema: input.KeySchema, + LatestStreamArn: nil, + LatestStreamLabel: nil, + LocalSecondaryIndexes: func() []types.LocalSecondaryIndexDescription { + var out []types.LocalSecondaryIndexDescription + + for _, l := range input.LocalSecondaryIndexes { + out = append(out, types.LocalSecondaryIndexDescription{ + IndexArn: aws.String(fmt.Sprintf("arn:aws:dynamodb:eu-west-1:123456789012:table/%s/index/%s", tableName, *l.IndexName)), + IndexName: l.IndexName, + IndexSizeBytes: aws.Int64(0), + ItemCount: aws.Int64(0), + KeySchema: l.KeySchema, + Projection: l.Projection, + }) + } + + return out + }(), + MultiRegionConsistency: types.MultiRegionConsistencyEventual, + OnDemandThroughput: input.OnDemandThroughput, + ProvisionedThroughput: func() *types.ProvisionedThroughputDescription { + if input.ProvisionedThroughput == nil { + return nil + } + + return &types.ProvisionedThroughputDescription{ + LastDecreaseDateTime: aws.Time(time.Now()), + LastIncreaseDateTime: aws.Time(time.Now()), + NumberOfDecreasesToday: aws.Int64(0), + ReadCapacityUnits: input.ProvisionedThroughput.ReadCapacityUnits, + WriteCapacityUnits: input.ProvisionedThroughput.WriteCapacityUnits, + } + }(), + Replicas: nil, + RestoreSummary: nil, + SSEDescription: nil, + StreamSpecification: nil, + TableArn: aws.String(fmt.Sprintf("arn:aws:dynamodb:eu-west-1:123456789012:table/%s", tableName)), + TableClassSummary: nil, + TableId: nil, + TableName: input.TableName, + TableSizeBytes: aws.Int64(0), + TableStatus: types.TableStatusActive, + WarmThroughput: func() *types.TableWarmThroughputDescription { + if input.WarmThroughput == nil { + return nil + } + + return &types.TableWarmThroughputDescription{ + ReadUnitsPerSecond: input.WarmThroughput.ReadUnitsPerSecond, + WriteUnitsPerSecond: input.WarmThroughput.WriteUnitsPerSecond, + Status: types.TableStatusActive, + } + }(), + } + client.TableDescriptions[tableName] = desc + + return &dynamodb.CreateTableOutput{ + TableDescription: &desc, + }, nil + } +} + +func withDefaultDescribeTableCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.DescribeTableCalls = append(m.DescribeTableCalls, defaultDescribeTableCall(m, err)) + } +} + +func defaultDescribeTableCall(client *mockClient, err error) ddbCall[dynamodb.DescribeTableInput, dynamodb.DescribeTableOutput] { + return func(_ Client, _ context.Context, input *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) { + if err != nil { + return nil, err + } + + desc, found := client.TableDescriptions[aws.ToString(input.TableName)] + if !found { + return nil, fmt.Errorf("table %q not found", aws.ToString(input.TableName)) + } + + return &dynamodb.DescribeTableOutput{ + Table: &desc, + }, nil + } +} + +func withDefaultDeleteTableCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.DeleteTableCalls = append(m.DeleteTableCalls, defaultDeleteTableCall(m, err)) + } +} + +func defaultDeleteTableCall(client *mockClient, err error) ddbCall[dynamodb.DeleteTableInput, dynamodb.DeleteTableOutput] { + return func(_ Client, _ context.Context, input *dynamodb.DeleteTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteTableOutput, error) { + if err != nil { + return nil, err + } + + desc, found := client.TableDescriptions[aws.ToString(input.TableName)] + if !found { + return nil, fmt.Errorf("table %q not found", aws.ToString(input.TableName)) + } + + delete(client.TableDescriptions, aws.ToString(input.TableName)) + + return &dynamodb.DeleteTableOutput{ + TableDescription: &desc, + }, nil + } +} + +func withDefaultGetItemCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.GetItemCalls = append(m.GetItemCalls, defaultGetItemCall(m, err)) + } +} + +func defaultGetItemCall(client *mockClient, err error) ddbCall[dynamodb.GetItemInput, dynamodb.GetItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) { + if err != nil { + return nil, err + } + + tableName := aws.ToString(input.TableName) + + if len(client.Items) == 0 || len(client.Items[tableName]) == 0 { + return &dynamodb.GetItemOutput{ + Item: nil, + }, nil + } + + item := client.Items[tableName][0] + client.Items[tableName] = client.Items[tableName][1:] + + return &dynamodb.GetItemOutput{ + Item: item, + }, nil + } +} + +func withDefaultPutItemCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.PutItemCalls = append(m.PutItemCalls, defaultPutItemCall(m, err)) + } +} + +func defaultPutItemCall(client *mockClient, err error) ddbCall[dynamodb.PutItemInput, dynamodb.PutItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) { + if err != nil { + return nil, err + } + + tableName := aws.ToString(input.TableName) + + if client.Items == nil { + client.Items = make(map[string][]map[string]types.AttributeValue) + } + + client.Items[tableName] = append(client.Items[tableName], input.Item) + + return &dynamodb.PutItemOutput{ + Attributes: input.Item, + }, nil + } +} + +func withDefaultDeleteItemCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.DeleteItemCalls = append(m.DeleteItemCalls, defaultDeleteItemCall(m, err)) + } +} + +func defaultDeleteItemCall(client *mockClient, err error) ddbCall[dynamodb.DeleteItemInput, dynamodb.DeleteItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) { + if err != nil { + return nil, err + } + + tableName := aws.ToString(input.TableName) + + if len(client.Items) == 0 || len(client.Items[tableName]) == 0 { + return &dynamodb.DeleteItemOutput{ + Attributes: nil, + }, nil + } + + item := client.Items[tableName][0] + client.Items[tableName] = client.Items[tableName][1:] + + return &dynamodb.DeleteItemOutput{ + Attributes: item, + }, err + } +} + +func withDefaultUpdateItemCall(err error) mockClientSetupFn { + return func(m *mockClient) { + m.UpdateItemCalls = append(m.UpdateItemCalls, defaultUpdateItemCall(m, err)) + } +} + +func defaultUpdateItemCall(client *mockClient, err error) ddbCall[dynamodb.UpdateItemInput, dynamodb.UpdateItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) { + if err != nil { + return nil, err + } + + tableName := aws.ToString(input.TableName) + + item := map[string]types.AttributeValue{} + maps.Copy(item, input.Key) + + for k, v := range input.ExpressionAttributeValues { + if nk, found := input.ExpressionAttributeNames[k]; found { + k = nk + } + + item[k] = v + } + + if client.Items == nil { + client.Items = make(map[string][]map[string]types.AttributeValue) + } + + // naive implementation; always assume insert since update supports upserts + client.Items[tableName] = append(client.Items[tableName], item) + + return &dynamodb.UpdateItemOutput{ + Attributes: item, + }, nil + } +} + +func withDefaultBatchGetItemCall(err error, retCounts map[string]uint) mockClientSetupFn { + return func(m *mockClient) { + m.BatchGetItemCalls = append(m.BatchGetItemCalls, defaultBatchGetItemCall(m, err, retCounts)) + } +} + +func defaultBatchGetItemCall(client *mockClient, err error, retCounts map[string]uint) ddbCall[dynamodb.BatchGetItemInput, dynamodb.BatchGetItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.BatchGetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchGetItemOutput, error) { + if err != nil { + return nil, err + } + + if len(retCounts) == 0 { + return &dynamodb.BatchGetItemOutput{}, nil + } + + if len(client.Items) == 0 { + return nil, errors.New("items have already been exhausted") + } + + out := &dynamodb.BatchGetItemOutput{ + Responses: make(map[string][]map[string]types.AttributeValue), + } + + for tableName, retCount := range retCounts { + // allow tests to force 0 responses + if retCount == 0 { + continue + } + + items := client.Items[tableName][0:retCount] + client.Items[tableName] = client.Items[tableName][len(items):] + + out.Responses[tableName] = items + + if len(client.Items[tableName]) > 0 { + if out.UnprocessedKeys == nil { + out.UnprocessedKeys = make(map[string]types.KeysAndAttributes) + } + out.UnprocessedKeys[tableName] = types.KeysAndAttributes{ + Keys: input.RequestItems[tableName].Keys[len(items):], + } + } + } + + return out, nil + } +} + +func withDefaultBatchWriteItemCall(err error, retCounts map[string]uint) mockClientSetupFn { + return func(m *mockClient) { + m.BatchWriteItemCalls = append(m.BatchWriteItemCalls, defaultBatchWriteItemCall(m, err, retCounts)) + } +} + +func defaultBatchWriteItemCall(client *mockClient, err error, retCounts map[string]uint) ddbCall[dynamodb.BatchWriteItemInput, dynamodb.BatchWriteItemOutput] { + return func(_ Client, _ context.Context, input *dynamodb.BatchWriteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchWriteItemOutput, error) { + if err != nil { + return nil, err + } + + if len(retCounts) == 0 { + return &dynamodb.BatchWriteItemOutput{}, nil + } + + if len(client.Items) == 0 { + return nil, errors.New("items have already been exhausted") + } + + out := &dynamodb.BatchWriteItemOutput{} + + for tableName, retCount := range retCounts { + // allow tests to force 0 responses + if retCount == 0 { + continue + } + + if len(input.RequestItems[tableName]) < int(retCount) { + continue + } + + items := input.RequestItems[tableName][:retCount] + for _, i := range items { + if i.PutRequest != nil { + client.Items[tableName] = append(client.Items[tableName], i.PutRequest.Item) + } + if i.DeleteRequest != nil { + client.Items[tableName] = client.Items[tableName][1:] + } + } + + if len(client.Items[tableName]) > 0 { + if out.UnprocessedItems == nil { + out.UnprocessedItems = make(map[string][]types.WriteRequest) + } + + out.UnprocessedItems[tableName] = input.RequestItems[tableName][retCount:] + } + } + + return out, nil + } +} + +func withDefaultScanCall(err error, retCount uint) mockClientSetupFn { + return func(m *mockClient) { + m.ScanCalls = append(m.ScanCalls, defaultScanCall(m, err, retCount)) + } +} + +func defaultScanCall(client *mockClient, err error, retCount uint) ddbCall[dynamodb.ScanInput, dynamodb.ScanOutput] { + return func(_ Client, _ context.Context, input *dynamodb.ScanInput, optFns ...func(*dynamodb.Options)) (*dynamodb.ScanOutput, error) { + if err != nil { + return nil, err + } + + if retCount == 0 { + return &dynamodb.ScanOutput{}, nil + } + + if len(client.Items) == 0 { + return nil, errors.New("items have already been exhausted") + } + + tableName := aws.ToString(input.TableName) + + items := client.Items[tableName][0:retCount] + out := &dynamodb.ScanOutput{ + Items: items, + } + + client.Items[tableName] = client.Items[tableName][len(items):] + + if len(client.Items) > 0 { + out.LastEvaluatedKey = items[len(items)-1] + } + + return out, nil + } +} + +func withDefaultQueryCall(err error, retCount uint) mockClientSetupFn { + return func(m *mockClient) { + m.QueryCalls = append(m.QueryCalls, defaultQueryCall(m, err, retCount)) + } +} + +func defaultQueryCall(client *mockClient, err error, retCount uint) ddbCall[dynamodb.QueryInput, dynamodb.QueryOutput] { + return func(_ Client, _ context.Context, input *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) { + if err != nil { + return nil, err + } + + if retCount == 0 { + return &dynamodb.QueryOutput{}, nil + } + + if len(client.Items) == 0 { + return nil, errors.New("items have already been exhausted") + } + + tableName := aws.ToString(input.TableName) + + items := client.Items[tableName][0:retCount] + out := &dynamodb.QueryOutput{ + Items: items, + } + + client.Items[tableName] = client.Items[tableName][len(items):] + + if len(client.Items) > 0 { + out.LastEvaluatedKey = items[len(items)-1] + } + + return out, nil + } +} + +func withItem(tableName string, item map[string]types.AttributeValue) mockClientSetupFn { + return func(m *mockClient) { + if m.Items == nil { + m.Items = make(map[string][]map[string]types.AttributeValue) + } + + m.Items[tableName] = append(m.Items[tableName], item) + } +} + +func withItems(tableName string, generator func() map[string]types.AttributeValue, count uint) mockClientSetupFn { + return func(m *mockClient) { + if m.Items == nil { + m.Items = make(map[string][]map[string]types.AttributeValue) + } + + for i := count; i > 0; i-- { + m.Items[tableName] = append(m.Items[tableName], generator()) + } + } +} + +type expectFn func(*testing.T, *mockClient) error + +func withExpectFns(fn expectFn) mockClientSetupFn { + return func(m *mockClient) { + m.Expects = append(m.Expects, fn) + } +} + +func expectTablesCount(c uint) expectFn { + return func(t *testing.T, m *mockClient) error { + if len(m.TableDescriptions) != int(c) { + return fmt.Errorf("expected %d tables, but found %d", c, len(m.TableDescriptions)) + } + + return nil + } +} + +func expectTable(tableName string) expectFn { + return func(t *testing.T, m *mockClient) error { + if _, found := m.TableDescriptions[tableName]; !found { + return fmt.Errorf("expected table %q not found", tableName) + } + + return nil + } +} + +func expectItemsCount(tableName string, c uint) expectFn { + return func(t *testing.T, m *mockClient) error { + if len(m.Items[tableName]) != int(c) { + return fmt.Errorf("expected %d items, but found %d", c, len(m.Items[tableName])) + } + + return nil + } +} diff --git a/feature/dynamodb/entitymanager/schema.go b/feature/dynamodb/entitymanager/schema.go new file mode 100644 index 00000000000..29059793f63 --- /dev/null +++ b/feature/dynamodb/entitymanager/schema.go @@ -0,0 +1,143 @@ +package entitymanager + +import ( + "fmt" + "reflect" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// Schema defines the structure and metadata for a DynamoDB table item of type T. +// It encapsulates table configuration, key schema, attribute definitions, and options +// for encoding/decoding items and managing table operations. +type Schema[T any] struct { + options SchemaOptions + cachedFields *CachedFields + enc *Encoder[T] + dec *Decoder[T] + typ reflect.Type + + // common + attributeDefinitions []types.AttributeDefinition + keySchema []types.KeySchemaElement + tableName *string + billingMode types.BillingMode + deletionProtectionEnabled *bool + onDemandThroughput *types.OnDemandThroughput + provisionedThroughput *types.ProvisionedThroughput + sseSpecification *types.SSESpecification + streamSpecification *types.StreamSpecification + tableClass types.TableClass + warmThroughput *types.WarmThroughput + // create + globalSecondaryIndexes []types.GlobalSecondaryIndex + localSecondaryIndexes []types.LocalSecondaryIndex + resourcePolicy *string + tags []types.Tag + // update + multiRegionConsistency types.MultiRegionConsistency + replicaUpdates []types.ReplicationGroupUpdate +} + +// createTableInput constructs a CreateTableInput for the DynamoDB table defined by this schema. +// It uses the schema's configuration and options to build the request. +func (s *Schema[T]) createTableInput() (*dynamodb.CreateTableInput, error) { + return &dynamodb.CreateTableInput{ + TableName: s.TableName(), + KeySchema: s.KeySchema(), + AttributeDefinitions: s.AttributeDefinitions(), + BillingMode: s.BillingMode(), + DeletionProtectionEnabled: s.DeletionProtectionEnabled(), + GlobalSecondaryIndexes: s.GlobalSecondaryIndexes(), + LocalSecondaryIndexes: s.LocalSecondaryIndexes(), + OnDemandThroughput: s.OnDemandThroughput(), + ProvisionedThroughput: s.ProvisionedThroughput(), + ResourcePolicy: s.ResourcePolicy(), + SSESpecification: s.SSESpecification(), + StreamSpecification: s.StreamSpecification(), + TableClass: s.TableClass(), + Tags: s.Tags(), + WarmThroughput: s.WarmThroughput(), + }, nil +} + +// describeTableInput constructs a DescribeTableInput for the DynamoDB table defined by this schema. +// It returns the request for describing the table's metadata and status. +func (s *Schema[T]) describeTableInput() (*dynamodb.DescribeTableInput, error) { + return &dynamodb.DescribeTableInput{ + TableName: s.TableName(), + }, nil +} + +// deleteTableInput constructs a DeleteTableInput for the DynamoDB table defined by this schema. +// It returns the request for deleting the table. +func (s *Schema[T]) deleteTableInput() (*dynamodb.DeleteTableInput, error) { + return &dynamodb.DeleteTableInput{ + TableName: s.TableName(), + }, nil +} + +// createKeyMap generates a key map for the given item using the schema's key definition. +// The returned map can be used for DynamoDB key-based operations (e.g., GetItem, DeleteItem). +func (s *Schema[T]) createKeyMap(item *T) (Map, error) { + m, err := s.Encode(item) + if err != nil { + return nil, err + } + + for _, f := range s.cachedFields.fields { + if !f.Partition && !f.Sort { + delete(m, f.Name) + } + } + + return m, nil +} + +// NewSchema creates a new Schema[T] instance for the given item type T. +// Optional configuration functions can be provided to customize schema options. +func NewSchema[T any](fns ...func(options *SchemaOptions)) (*Schema[T], error) { + if reflect.TypeFor[T]().Kind() != reflect.Struct { + return nil, fmt.Errorf("NewClient() can only be created from structs, %T given", *new(T)) + } + + t := new(T) + cf := unionStructFields(reflect.TypeOf(*t), structFieldOptions{}) + + opts := SchemaOptions{} + + for _, fn := range fns { + fn(&opts) + } + + s := &Schema[T]{ + options: opts, + cachedFields: cf, + typ: reflect.TypeFor[T](), + enc: NewEncoder[T](func(options *EncoderOptions) { + options.ConverterRegistry = opts.ConverterRegistry + options.IgnoreNilValueErrors = opts.IgnoreNilValueErrors + }), + dec: NewDecoder[T](func(options *DecoderOptions) { + options.ConverterRegistry = opts.ConverterRegistry + options.IgnoreNilValueErrors = opts.IgnoreNilValueErrors + }), + } + + resolversFns := []func(o *Schema[T]) error{ + (*Schema[T]).defaults, + (*Schema[T]).resolveTableName, + (*Schema[T]).resolveKeySchema, + (*Schema[T]).resolveAttributeDefinitions, + (*Schema[T]).resolveSecondaryIndexes, + } + + for _, fn := range resolversFns { + if err := fn(s); err != nil { + return nil, err + } + } + + return s, nil +} diff --git a/feature/dynamodb/entitymanager/schema_builder_methods.go b/feature/dynamodb/entitymanager/schema_builder_methods.go new file mode 100644 index 00000000000..1ab8445d2cb --- /dev/null +++ b/feature/dynamodb/entitymanager/schema_builder_methods.go @@ -0,0 +1,282 @@ +package entitymanager + +import ( + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// AttributeDefinitions returns the attribute definitions for the DynamoDB table schema. +func (s *Schema[T]) AttributeDefinitions() []types.AttributeDefinition { + return s.attributeDefinitions +} + +// WithAttributeDefinitions sets the attribute definitions for the schema and returns the updated schema. +func (s *Schema[T]) WithAttributeDefinitions(attributeDefinitions []types.AttributeDefinition) *Schema[T] { + s.attributeDefinitions = attributeDefinitions + + return s +} + +// KeySchema returns the key schema elements for the DynamoDB table. +func (s *Schema[T]) KeySchema() []types.KeySchemaElement { + return s.keySchema +} + +// WithKeySchema sets the key schema for the table and returns the updated schema. +func (s *Schema[T]) WithKeySchema(keySchema []types.KeySchemaElement) *Schema[T] { + s.keySchema = keySchema + + return s +} + +// TableName returns the name of the DynamoDB table. +func (s *Schema[T]) TableName() *string { + return s.tableName +} + +// WithTableName sets the table name and returns the updated schema. +func (s *Schema[T]) WithTableName(tableName *string) *Schema[T] { + s.tableName = tableName + + return s +} + +// BillingMode returns the billing mode for the DynamoDB table. +func (s *Schema[T]) BillingMode() types.BillingMode { + return s.billingMode +} + +// WithBillingMode sets the billing mode for the table and returns the updated schema. +func (s *Schema[T]) WithBillingMode(billingMode types.BillingMode) *Schema[T] { + s.billingMode = billingMode + + return s +} + +// DeletionProtectionEnabled returns whether deletion protection is enabled for the table. +func (s *Schema[T]) DeletionProtectionEnabled() *bool { + return s.deletionProtectionEnabled +} + +// WithDeletionProtectionEnabled sets deletion protection for the table and returns the updated schema. +func (s *Schema[T]) WithDeletionProtectionEnabled(deletionProtectionEnabled *bool) *Schema[T] { + s.deletionProtectionEnabled = deletionProtectionEnabled + + return s +} + +// GlobalSecondaryIndexes returns the global secondary indexes for the table. +func (s *Schema[T]) GlobalSecondaryIndexes() []types.GlobalSecondaryIndex { + if len(s.globalSecondaryIndexes) == 0 { + return nil + } + + return s.globalSecondaryIndexes +} + +// WithGlobalSecondaryIndexes overwrites the global secondary indexes and returns the updated schema. +func (s *Schema[T]) WithGlobalSecondaryIndexes(globalSecondaryIndexes []types.GlobalSecondaryIndex) *Schema[T] { + s.globalSecondaryIndexes = globalSecondaryIndexes + + return s +} + +// WithGlobalSecondaryIndex creates or updates a global secondary index by name using the provided function. +// If the index does not exist, it is created. Returns the updated schema. +func (s *Schema[T]) WithGlobalSecondaryIndex(name string, fn func(gsi *types.GlobalSecondaryIndex)) *Schema[T] { + var gsi *types.GlobalSecondaryIndex + for idx := range s.globalSecondaryIndexes { + gsi = &s.globalSecondaryIndexes[idx] + if gsi.IndexName != nil && *gsi.IndexName == name { + fn(gsi) + break + } + } + + if gsi == nil { + gsi = &types.GlobalSecondaryIndex{ + IndexName: pointer(name), + Projection: &types.Projection{ + ProjectionType: types.ProjectionTypeAll, + }, + } + fn(gsi) + s.globalSecondaryIndexes = append(s.globalSecondaryIndexes, *gsi) + } + + attrs := map[string]bool{} + for _, ad := range s.attributeDefinitions { + attrs[*ad.AttributeName] = true + } + + for _, ks := range gsi.KeySchema { + if _, ok := attrs[*ks.AttributeName]; !ok { + f, _ := s.cachedFields.FieldByName(*ks.AttributeName) + at, _ := typeToScalarAttributeType(f.Type) + s.attributeDefinitions = append(s.attributeDefinitions, types.AttributeDefinition{ + AttributeName: ks.AttributeName, + AttributeType: at, + }) + } + } + + return s +} + +// LocalSecondaryIndexes returns the local secondary indexes for the table. +func (s *Schema[T]) LocalSecondaryIndexes() []types.LocalSecondaryIndex { + if len(s.localSecondaryIndexes) == 0 { + return nil + } + + return s.localSecondaryIndexes +} + +// WithLocalSecondaryIndexes overwrites the local secondary indexes and returns the updated schema. +func (s *Schema[T]) WithLocalSecondaryIndexes(localSecondaryIndexes []types.LocalSecondaryIndex) *Schema[T] { + s.localSecondaryIndexes = localSecondaryIndexes + + return s +} + +// WithLocalSecondaryIndex creates or updates a local secondary index by name using the provided function. +// If the index does not exist, it is created. Returns the updated schema. +func (s *Schema[T]) WithLocalSecondaryIndex(name string, fn func(gsi *types.LocalSecondaryIndex)) *Schema[T] { + existing := false + for idx := range s.localSecondaryIndexes { + lsi := s.localSecondaryIndexes[idx] + if lsi.IndexName != nil && *lsi.IndexName == name { + fn(&lsi) + existing = true + } + } + + if !existing { + lsi := types.LocalSecondaryIndex{ + IndexName: pointer(name), + } + fn(&lsi) + s.localSecondaryIndexes = append(s.localSecondaryIndexes, lsi) + } + + return s +} + +// OnDemandThroughput returns the on-demand throughput settings for the table. +func (s *Schema[T]) OnDemandThroughput() *types.OnDemandThroughput { + return s.onDemandThroughput +} + +// WithOnDemandThroughput sets the on-demand throughput and returns the updated schema. +func (s *Schema[T]) WithOnDemandThroughput(onDemandThroughput *types.OnDemandThroughput) *Schema[T] { + s.onDemandThroughput = onDemandThroughput + + return s +} + +// ProvisionedThroughput returns the provisioned throughput settings for the table. +func (s *Schema[T]) ProvisionedThroughput() *types.ProvisionedThroughput { + return s.provisionedThroughput +} + +// WithProvisionedThroughput sets the provisioned throughput and returns the updated schema. +func (s *Schema[T]) WithProvisionedThroughput(provisionedThroughput *types.ProvisionedThroughput) *Schema[T] { + s.provisionedThroughput = provisionedThroughput + + return s +} + +// ResourcePolicy returns the resource policy for the table. +func (s *Schema[T]) ResourcePolicy() *string { + return s.resourcePolicy +} + +// WithResourcePolicy sets the resource policy and returns the updated schema. +func (s *Schema[T]) WithResourcePolicy(resourcePolicy *string) *Schema[T] { + s.resourcePolicy = resourcePolicy + + return s +} + +// SSESpecification returns the server-side encryption specification for the table. +func (s *Schema[T]) SSESpecification() *types.SSESpecification { + return s.sseSpecification +} + +// WithSSESpecification sets the server-side encryption specification and returns the updated schema. +func (s *Schema[T]) WithSSESpecification(sseSpecification *types.SSESpecification) *Schema[T] { + s.sseSpecification = sseSpecification + + return s +} + +// StreamSpecification returns the stream specification for the table. +func (s *Schema[T]) StreamSpecification() *types.StreamSpecification { + return s.streamSpecification +} + +// WithStreamSpecification sets the stream specification and returns the updated schema. +func (s *Schema[T]) WithStreamSpecification(streamSpecification *types.StreamSpecification) *Schema[T] { + s.streamSpecification = streamSpecification + + return s +} + +// TableClass returns the table class for the DynamoDB table. +func (s *Schema[T]) TableClass() types.TableClass { + return s.tableClass +} + +// WithTableClass sets the table class and returns the updated schema. +func (s *Schema[T]) WithTableClass(tableClass types.TableClass) *Schema[T] { + s.tableClass = tableClass + + return s +} + +// Tags returns the tags associated with the table. +func (s *Schema[T]) Tags() []types.Tag { + return s.tags +} + +// WithTags sets the tags for the table and returns the updated schema. +func (s *Schema[T]) WithTags(tags []types.Tag) *Schema[T] { + s.tags = tags + + return s +} + +// WarmThroughput returns the warm throughput settings for the table. +func (s *Schema[T]) WarmThroughput() *types.WarmThroughput { + return s.warmThroughput +} + +// WithWarmThroughput sets the warm throughput and returns the updated schema. +func (s *Schema[T]) WithWarmThroughput(warmThroughput *types.WarmThroughput) *Schema[T] { + s.warmThroughput = warmThroughput + + return s +} + +// MultiRegionConsistency returns the multi-region consistency setting for the table. +func (s *Schema[T]) MultiRegionConsistency() types.MultiRegionConsistency { + return s.multiRegionConsistency +} + +// WithMultiRegionConsistency sets the multi-region consistency and returns the updated schema. +func (s *Schema[T]) WithMultiRegionConsistency(multiRegionConsistency types.MultiRegionConsistency) *Schema[T] { + s.multiRegionConsistency = multiRegionConsistency + + return s +} + +// ReplicaUpdates returns the replication group updates for the table. +func (s *Schema[T]) ReplicaUpdates() []types.ReplicationGroupUpdate { + return s.replicaUpdates +} + +// WithReplicaUpdates sets the replication group updates and returns the updated schema. +func (s *Schema[T]) WithReplicaUpdates(replicaUpdates []types.ReplicationGroupUpdate) *Schema[T] { + s.replicaUpdates = replicaUpdates + + return s +} diff --git a/feature/dynamodb/entitymanager/schema_options.go b/feature/dynamodb/entitymanager/schema_options.go new file mode 100644 index 00000000000..333016343a7 --- /dev/null +++ b/feature/dynamodb/entitymanager/schema_options.go @@ -0,0 +1,24 @@ +package entitymanager + +import "github.com/aws/aws-sdk-go-v2/feature/dynamodb/entitymanager/converters" + +// SchemaOptions defines configuration options for Schema behavior. +type SchemaOptions struct { + // ErrorOnMissingField controls whether decoding should return an error + // when a field is missing in the destination struct. + // If true, decoding will fail when the schema field cannot be matched. + // If false or nil, missing fields will be ignored. + ErrorOnMissingField *bool + + // IgnoreNilValueErrors controls whether decoding should ignore errors + // caused by nil values during schema conversion. + // If true, fields with nil values that cause conversion errors will be skipped. + // If false or nil, such cases will trigger an error. + IgnoreNilValueErrors *bool + + // ConverterRegistry provides a registry of type converters used during + // encoding and decoding operations. It will be set on both the Decoder + // and Encoder to control how values are transformed between Go types + // and schema representations. + ConverterRegistry *converters.Registry +} diff --git a/feature/dynamodb/entitymanager/schema_resolvers.go b/feature/dynamodb/entitymanager/schema_resolvers.go new file mode 100644 index 00000000000..3c09bbb9606 --- /dev/null +++ b/feature/dynamodb/entitymanager/schema_resolvers.go @@ -0,0 +1,328 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "slices" + "strings" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// defaults sets default values for the Schema[T] if not already specified. +// Ensures billing mode is set to PayPerRequest if unset. +func (s *Schema[T]) defaults() error { + if s.billingMode == "" { + s.billingMode = types.BillingModePayPerRequest + } + + return nil +} + +// resolveTableName determines and sets the DynamoDB table name for the schema based on the type T. +// Returns an error if T is not a struct or pointer to struct. +func (s *Schema[T]) resolveTableName() error { + if s.typ == nil { + s.typ = reflect.TypeFor[T]() + } + + r := s.typ + if r.Kind() == reflect.Ptr { + r = r.Elem() + } + + if r.Kind() != reflect.Struct { + return fmt.Errorf("resolveTableName() expected the type to be a struct or struct pointer, got: %V", reflect.New(s.typ).Interface()) + } + + if s.tableName == nil { + s.tableName = pointer(r.Name()) + } + + return nil +} + +// resolveKeySchema analyzes the cached fields and sets the key schema for the table. +// Ensures exactly one partition key and at most one sort key are defined. +// Returns an error if the key configuration is invalid. +func (s *Schema[T]) resolveKeySchema() error { + if len(s.keySchema) > 0 { + return nil + } + + var primary []string + var sort []string + + for _, f := range s.cachedFields.fields { + if f.Tag.Partition && f.Tag.Sort { + return fmt.Errorf("Field %s is both primary and sort", f.Name) + } + + if f.Tag.Partition { + primary = append(primary, f.Name) + } + + if f.Tag.Sort { + sort = append(sort, f.Name) + } + } + + cp := len(primary) + if cp != 1 { + return fmt.Errorf("exactly 1 partition field is expected, %d given, fields: %s", len(primary), strings.Join(primary, ", ")) + } + + cs := len(sort) + if cs > 1 { + return fmt.Errorf("exactly 0 or 1 sort field is expected, %d given, fields: %s", len(sort), strings.Join(sort, ", ")) + } + + s.keySchema = make([]types.KeySchemaElement, cp+cs) + s.keySchema[0].AttributeName = &primary[0] + s.keySchema[0].KeyType = types.KeyTypeHash + + if cs > 0 { + s.keySchema[1].AttributeName = &sort[0] + s.keySchema[1].KeyType = types.KeyTypeRange + } + + return nil +} + +// resolveAttributeDefinitions populates the attribute definitions for the table based on key fields and indexes. +// Only fields used as keys or indexes are included. +func (s *Schema[T]) resolveAttributeDefinitions() error { + for _, f := range s.cachedFields.fields { + isKey := f.Tag.Partition || f.Tag.Sort + for _, i := range f.Tag.Indexes { + isKey = isKey || i.Partition || i.Sort + } + + if !isKey { + continue + } + + at, ok := typeToScalarAttributeType(f.Type) + if ok != true { + continue + } + + s.attributeDefinitions = append(s.attributeDefinitions, types.AttributeDefinition{ + AttributeName: &f.Name, + AttributeType: at, + }) + } + + return nil +} + +// extractIndexes analyzes field index tags and returns mappings for global and local secondary indexes. +// Returns error if index configuration is ambiguous or invalid. +func extractIndexes(fields []Field) (map[string][][]int, map[string][][]int, error) { + globals := make(map[string][][]int) + locals := make(map[string][][]int) + unknowns := make(map[string][][]int) + + // collect index data + for f, fld := range fields { + for i, idx := range fld.Indexes { + if idx.Global && idx.Local { + return nil, nil, fmt.Errorf(`Field "%s" for index "%s" is configured to be both local and global`, fld.Name, idx.Name) + } + + if idx.Partition && idx.Sort { + return nil, nil, fmt.Errorf(`Field "%s" for index "%s" is configured to be both primarty and sort`, fld.Name, idx.Name) + } + + if idx.Partition && idx.Local { + return nil, nil, fmt.Errorf(`Field "%s" for index "%s" is configured to be the primarty key for a local index, local indexes inherit the primary from the table`, fld.Name, idx.Name) + } + + pos := []int{f, i} + + switch { + case idx.Global: + globals[idx.Name] = append(globals[idx.Name], pos) + case idx.Local: + locals[idx.Name] = append(locals[idx.Name], pos) + case !idx.Global && !idx.Local: + unknowns[idx.Name] = append(unknowns[idx.Name], pos) + } + } + } + + for name, positions := range unknowns { + _, gOk := globals[name] + _, lOk := locals[name] + + if gOk && lOk { + return nil, nil, fmt.Errorf(`index "%s" is configured both as global and local secondary index`, name) + } + if !gOk && !lOk { + return nil, nil, fmt.Errorf(`index "%s" type cannot be determined`, name) + } + + if gOk { + globals[name] = append(globals[name], positions...) + } + if lOk { + locals[name] = append(locals[name], positions...) + } + } + + return globals, locals, nil +} + +// resolveSecondaryIndexes processes the schema's fields to determine global and local secondary indexes. +// Populates the schema's index definitions and validates key configurations. +func (s *Schema[T]) resolveSecondaryIndexes() error { + globals, locals, err := extractIndexes(s.cachedFields.fields) + if err != nil { + return err + } + + var tablePrimary *types.KeySchemaElement + if len(s.keySchema) == 0 { + if err := s.resolveKeySchema(); err != nil { + return err + } + } + for _, ks := range s.keySchema { + if ks.KeyType == types.KeyTypeHash { + tablePrimary = &ks + } + } + + if tablePrimary == nil { + return fmt.Errorf("unable to determine the table primary key %v", s.TableName()) + } + + s.localSecondaryIndexes, err = processLSIs(s.cachedFields.fields, *tablePrimary, locals) + if err != nil { + return err + } + + s.globalSecondaryIndexes, err = processGSIs(s.cachedFields.fields, globals) + if err != nil { + return err + } + + return nil +} + +// processGSIs builds GlobalSecondaryIndex definitions from the provided global index mappings. +// Validates that each index has exactly one partition key and at most one sort key. +func processGSIs(fields []Field, globals map[string][][]int) ([]types.GlobalSecondaryIndex, error) { + gs := make([]types.GlobalSecondaryIndex, 0, len(globals)) + + // build globals + for name, positions := range globals { + numPrimaries := 0 + numSorts := 0 + + g := types.GlobalSecondaryIndex{ + IndexName: pointer(name), + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + } + + for _, pos := range positions { + f := fields[pos[0]] + i := f.Indexes[pos[1]] + + switch { + case i.Partition: + g.KeySchema = append(g.KeySchema, types.KeySchemaElement{ + AttributeName: pointer(f.Name), + KeyType: types.KeyTypeHash, + }) + numPrimaries++ + break + case i.Sort: + g.KeySchema = append(g.KeySchema, types.KeySchemaElement{ + AttributeName: pointer(f.Name), + KeyType: types.KeyTypeRange, + }) + numSorts++ + break + default: + g.Projection.NonKeyAttributes = append(g.Projection.NonKeyAttributes, f.Name) + g.Projection.ProjectionType = types.ProjectionTypeInclude + } + + // the hash must be first + if len(g.KeySchema) == 2 { + slices.SortStableFunc(g.KeySchema, ksSortFunc) + } + } + + if numPrimaries != 1 { + return nil, fmt.Errorf(`index "%s" has %d primary keys, it must have exactly 1`, name, numPrimaries) + } + + if numSorts > 1 { + return nil, fmt.Errorf(`index "%s" has %d sort keys, it must have exactly 0 or 1`, name, numSorts) + } + + gs = append(gs, g) + } + + return gs, nil +} + +// ksSortFunc sorts KeySchemaElements so that the hash key appears before the range key. +func ksSortFunc(a, b types.KeySchemaElement) int { + switch types.KeyTypeHash { + case a.KeyType: + return -1 + case b.KeyType: + return 1 + default: + return 0 + } +} + +// processLSIs builds LocalSecondaryIndex definitions from the provided local index mappings. +// Each local index inherits the table's primary key and may define a sort key. +func processLSIs(fields []Field, tablePrimary types.KeySchemaElement, locals map[string][][]int) ([]types.LocalSecondaryIndex, error) { + ls := make([]types.LocalSecondaryIndex, 0, len(locals)) + numSorts := 0 + + for name, positions := range locals { + l := types.LocalSecondaryIndex{ + IndexName: pointer(name), + KeySchema: []types.KeySchemaElement{ + tablePrimary, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + } + + for _, pos := range positions { + f := fields[pos[0]] + i := f.Indexes[pos[1]] + + switch { + case i.Partition: + return nil, fmt.Errorf(`index "%s" has Field "%s" is configured as the primary key for a secondary index, secondary indexes inherit the primary key of their table`, name, f.Name) + case i.Sort: + l.KeySchema = append(l.KeySchema, types.KeySchemaElement{ + AttributeName: pointer(f.Name), + KeyType: types.KeyTypeRange, + }) + numSorts++ + default: + l.Projection.NonKeyAttributes = append(l.Projection.NonKeyAttributes, f.Name) + l.Projection.ProjectionType = types.ProjectionTypeInclude + } + } + + ls = append(ls, l) + } + + return ls, nil +} diff --git a/feature/dynamodb/entitymanager/schema_resolvers_test.go b/feature/dynamodb/entitymanager/schema_resolvers_test.go new file mode 100644 index 00000000000..7602a05328b --- /dev/null +++ b/feature/dynamodb/entitymanager/schema_resolvers_test.go @@ -0,0 +1,453 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestResolveTableName(t *testing.T) { + cases := []struct { + input any + expected *string + error bool + }{ + { + input: &Schema[order]{}, + expected: pointer("order"), + }, + { + input: &Schema[*order]{}, + expected: pointer("order"), + }, + { + input: &Schema[address]{}, + expected: pointer("address"), + }, + { + input: &Schema[*address]{}, + expected: pointer("address"), + }, + { + input: &Schema[reflect.Value]{}, + expected: pointer("Value"), + }, + { + input: &Schema[*reflect.Value]{}, + expected: pointer("Value"), + }, + { + input: &Schema[any]{}, + expected: nil, + error: true, + }, + { + input: &Schema[string]{}, + expected: nil, + error: true, + }, + { + input: &Schema[[]byte]{}, + expected: nil, + error: true, + }, + { + input: &Schema[[]order]{}, + expected: nil, + error: true, + }, + } + + type tableNameResolver interface { + TableName() *string + resolveTableName() error + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + var actual tableNameResolver + var ok bool + + if actual, ok = c.input.(tableNameResolver); !ok && !c.error { + t.Fatalf("unable to check the presence of the resolveTableName() error method") + } + + err := actual.resolveTableName() + + if c.error && err == nil { + t.Fatalf("expected error") + } + + if !c.error && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if c.error && err != nil { + return + } + + if diff := cmpDiff(c.expected, actual.TableName()); len(diff) > 0 { + t.Errorf(`failed to resolve table name, expected: %s"`, diff) + } + }) + } +} + +func TestResolveKeySchema(t *testing.T) { + pk := "pk" + sk := "sk" + + cases := []struct { + input []Field + expected any + error bool + }{ + { + input: []Field{ + { + Name: "pk", + Tag: Tag{Partition: true}, + }, + { + Name: "sk", + Tag: Tag{Sort: true}, + }, + {}, + {}, + {}, + }, + expected: []types.KeySchemaElement{ + { + AttributeName: &pk, + KeyType: types.KeyTypeHash, + }, + { + AttributeName: &sk, + KeyType: types.KeyTypeRange, + }, + }, + error: false, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{Partition: true}, + }, + {}, + {}, + {}, + }, + expected: []types.KeySchemaElement{ + { + AttributeName: &pk, + KeyType: types.KeyTypeHash, + }, + }, + error: false, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{Partition: true}, + }, + { + Name: "sk", + Tag: Tag{Partition: true}, + }, + {}, + {}, + {}, + }, + expected: []types.KeySchemaElement{}, + error: true, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{Partition: true}, + }, + { + Name: "sk", + Tag: Tag{Sort: true}, + }, + { + Name: "sk", + Tag: Tag{Sort: true}, + }, + {}, + {}, + }, + expected: []types.KeySchemaElement{}, + error: true, + }, + { + input: []Field{ + { + Name: "pk", + }, + { + Name: "sk", + }, + {}, + {}, + }, + expected: []types.KeySchemaElement{}, + error: true, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &Schema[order]{ + cachedFields: &CachedFields{ + fields: c.input, + }, + } + + err := o.resolveKeySchema() + if c.error && err == nil { + t.Fatalf("expected error") + } + + if !c.error && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if c.error && err != nil { + return + } + + if diff := cmpDiff(c.expected, o.KeySchema()); len(diff) != 0 { + t.Fatalf("unexpected diff: %s", diff) + } + }) + } +} + +func TestResolveAttributeDefinitions(t *testing.T) { + cases := []struct { + input []Field + expected []types.AttributeDefinition + }{ + { + input: []Field{}, + expected: []types.AttributeDefinition(nil), + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{ + Partition: true, + }, + Type: reflect.TypeFor[string](), + }, + }, + expected: []types.AttributeDefinition{ + { + AttributeName: pointer("pk"), + AttributeType: types.ScalarAttributeTypeS, + }, + }, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{ + Partition: true, + }, + Type: reflect.TypeFor[int32](), + }, + }, + expected: []types.AttributeDefinition{ + { + AttributeName: pointer("pk"), + AttributeType: types.ScalarAttributeTypeN, + }, + }, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{ + Partition: true, + }, + Type: reflect.TypeFor[[]byte](), + }, + }, + expected: []types.AttributeDefinition{ + { + AttributeName: pointer("pk"), + AttributeType: types.ScalarAttributeTypeB, + }, + }, + }, + { + input: []Field{ + { + Name: "sk", + Tag: Tag{ + Sort: true, + }, + Type: reflect.TypeFor[[]byte](), + }, + }, + expected: []types.AttributeDefinition{ + { + AttributeName: pointer("sk"), + AttributeType: types.ScalarAttributeTypeB, + }, + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + s := &Schema[any]{ + cachedFields: &CachedFields{ + fields: c.input, + }, + } + + _ = s.resolveAttributeDefinitions() + + if diff := cmpDiff(c.expected, s.AttributeDefinitions()); len(diff) != 0 { + fmt.Printf("%#+v\n", c.expected) + fmt.Printf("%#+v\n", s.AttributeDefinitions()) + t.Fatalf("unexpected diff: %s", diff) + } + }) + } +} + +func TestResolveSecondaryIndexes(t *testing.T) { + cases := []struct { + input []Field + expectedLSIs []types.LocalSecondaryIndex + expectedGSIs []types.GlobalSecondaryIndex + error bool + }{ + { + error: true, + }, + { + input: []Field{}, + error: true, + }, + { + input: []Field{ + { + Name: "pk", + Tag: Tag{ + Partition: true, + Indexes: []Index{ + { + Name: "gsi1", + Global: true, + Partition: true, + }, + { + Name: "gsi2", + Sort: true, + }, + }, + }, + }, + { + Name: "sk", + Tag: Tag{ + Indexes: []Index{ + { + Name: "gsi1", + Sort: true, + }, + { + Name: "gsi2", + Global: true, + Partition: true, + }, + }, + }, + }, + }, + expectedLSIs: []types.LocalSecondaryIndex(nil), + expectedGSIs: []types.GlobalSecondaryIndex{ + { + IndexName: pointer("gsi1"), + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + KeySchema: []types.KeySchemaElement{ + { + AttributeName: pointer("pk"), + KeyType: types.KeyTypeHash, + }, + { + AttributeName: pointer("sk"), + KeyType: types.KeyTypeRange, + }, + }, + }, + { + IndexName: pointer("gsi2"), + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + KeySchema: []types.KeySchemaElement{ + { + AttributeName: pointer("sk"), + KeyType: types.KeyTypeHash, + }, + { + AttributeName: pointer("pk"), + KeyType: types.KeyTypeRange, + }, + }, + }, + }, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + o := &Schema[order]{ + cachedFields: &CachedFields{ + fields: c.input, + }, + } + + err := o.resolveSecondaryIndexes() + if c.error && err == nil { + t.Fatalf("expected error") + } + + if !c.error && err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if c.error && err != nil { + return + } + + if diff := cmpDiff(c.expectedGSIs, o.GlobalSecondaryIndexes()); len(diff) != 0 { + t.Fatalf("unexpected diff in GSIs: %s", diff) + } + + if diff := cmpDiff(c.expectedLSIs, o.LocalSecondaryIndexes()); len(diff) != 0 { + t.Fatalf("unexpected diff in LSIs: %s", diff) + } + }) + } +} + +func TestResolveGlobalSecondaryIndexUpdates(t *testing.T) { +} diff --git a/feature/dynamodb/entitymanager/schema_test.go b/feature/dynamodb/entitymanager/schema_test.go new file mode 100644 index 00000000000..b772b573d20 --- /dev/null +++ b/feature/dynamodb/entitymanager/schema_test.go @@ -0,0 +1,708 @@ +package entitymanager + +import ( + "encoding/json" + "fmt" + "os" + "reflect" + "slices" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestSchema(t *testing.T) { + actual, err := NewSchema[order]() + if err != nil { + t.Fatalf("NewSchema error: %v", err) + } + + if len(actual.cachedFields.fields) != 18 { + t.Fatalf("expected %d CachedFields, found %d", 18, len(actual.cachedFields.fields)) + } + + // | Index Name | Partition Key | Sort Key | Type | Notes | + // | ------------------- | ------------- | ------------ | ---- | --------------------------- | + // | `CustomerIndex` | `customer_id` | `created_at` | GSI | Already present | + // | `TotalAmountIndex` | `total` | `order_id` | GSI | Useful for order bucketing | + // | `OrderVersionIndex` | `order_id` | `version` | LSI | Good for optimistic locking | + // | `RegionIndex` | `zip` | (none) | GSI | Region-based querying | + // | `NoteIndex` | `note` | (optional) | GSI | Requires public exposure | + expected := &Schema[order]{ + typ: reflect.TypeFor[order](), + tableName: pointer("order"), + billingMode: types.BillingModePayPerRequest, + keySchema: []types.KeySchemaElement{ + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("created_at"), KeyType: types.KeyTypeRange}, + }, + attributeDefinitions: []types.AttributeDefinition{ + {AttributeName: pointer("order_id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("created_at"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("customer_id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("total"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("version"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("zip"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("note"), AttributeType: types.ScalarAttributeTypeS}, + }, + localSecondaryIndexes: []types.LocalSecondaryIndex{ + { + IndexName: pointer("OrderVersionIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("version"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + }, + globalSecondaryIndexes: []types.GlobalSecondaryIndex{ + { + IndexName: pointer("CustomerIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("customer_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("created_at"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("TotalAmountIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("total"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("RegionIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("zip"), KeyType: types.KeyTypeHash}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("NoteIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("note"), KeyType: types.KeyTypeHash}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + }, + cachedFields: &CachedFields{ + fields: []Field{ + { + Name: "order_id", + NameFromTag: true, + Index: []int{0}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "order_id", + Partition: true, + AutoGenerated: true, + Options: map[string][]string{"autogenerated": {"key"}}, + Indexes: []Index{ + { + Name: "TotalAmountIndex", + Global: true, + Sort: true, + }, + }, + }, + }, + { + Name: "created_at", + NameFromTag: true, + Index: []int{1}, + Type: reflect.TypeFor[int64](), + Tag: Tag{ + Name: "created_at", + Sort: true, + AutoGenerated: true, + Options: map[string][]string{"autogenerated": {"timestamp"}}, + Indexes: []Index{ + { + Name: "CustomerIndex", + Sort: true, + }, + }, + }, + }, + { + Name: "updated_at", + NameFromTag: true, + Index: []int{2}, + Type: reflect.TypeFor[time.Time](), + Tag: Tag{ + Name: "updated_at", + AutoGenerated: true, + Options: map[string][]string{"autogenerated": {"timestamp", "always"}}, + }, + }, + { + Name: "customer_id", + NameFromTag: true, + Index: []int{3}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "customer_id", + Indexes: []Index{ + { + Name: "CustomerIndex", + Global: true, + Partition: true, + }, + }, + }, + }, + { + Name: "total", + NameFromTag: true, + Index: []int{4}, + Type: reflect.TypeFor[float64](), + Tag: Tag{ + Name: "total", + Indexes: []Index{ + { + Name: "TotalAmountIndex", + Global: true, + Partition: true, + }, + }, + }, + }, + { + Name: "version", + NameFromTag: true, + Index: []int{6}, + Type: reflect.TypeFor[int64](), + Tag: Tag{ + Name: "version", + Version: true, + Indexes: []Index{ + { + Name: "OrderVersionIndex", + Sort: true, + Local: true, + }, + }, + }, + }, + { + Name: "versionString", + NameFromTag: true, + Index: []int{7}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "versionString", + Version: true, + }, + }, + { + Name: "counter_up", + NameFromTag: true, + Index: []int{8}, + Type: reflect.TypeFor[int64](), + Tag: Tag{ + Name: "counter_up", + AtomicCounter: true, + Options: map[string][]string{ + "atomiccounter": {"start=0", "delta=5"}, + }, + }, + }, + { + Name: "counter_down", + NameFromTag: true, + Index: []int{9}, + Type: reflect.TypeFor[int64](), + Tag: Tag{ + Name: "counter_down", + AtomicCounter: true, + Options: map[string][]string{ + "atomiccounter": {"start=0", "delta=-5"}, + }, + }, + }, + { + Name: "metadata", + NameFromTag: true, + Index: []int{10}, + Type: reflect.TypeFor[map[string]string](), + Tag: Tag{ + Name: "metadata", + }, + }, + { + Name: "street", + NameFromTag: true, + Index: []int{11, 0}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "street", + }, + }, + { + Name: "city", + NameFromTag: true, + Index: []int{11, 1}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "city", + }, + }, + { + Name: "zip", + NameFromTag: true, + Index: []int{11, 2}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "zip", + Indexes: []Index{ + { + Name: "RegionIndex", + Global: true, + Partition: true, + }, + }, + }, + }, + { + Name: "Notes", + Index: []int{12}, + Type: reflect.TypeFor[[]string](), + Tag: Tag{ + PreserveEmpty: true, + AsStrSet: true, + }, + }, + { + Name: "note", + NameFromTag: true, + Index: []int{13}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "note", + Getter: "Note", + Setter: "SetNote", + Indexes: []Index{ + { + Name: "NoteIndex", + Global: true, + Partition: true, + }, + }, + }, + }, + { + Name: "first_name", + NameFromTag: true, + Index: []int{14}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "first_name", + }, + }, + { + Name: "last_name", + NameFromTag: true, + Index: []int{15}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "last_name", + }, + }, + { + Name: "nick_name", + NameFromTag: true, + Index: []int{16}, + Type: reflect.TypeFor[string](), + Tag: Tag{ + Name: "nick_name", + }, + }, + }, + fieldsByName: map[string]int{ + "order_id": 0, + "created_at": 1, + "updated_at": 2, + "customer_id": 3, + "total": 4, + "version": 5, + "versionString": 6, + "counter_up": 7, + "counter_down": 8, + "metadata": 9, + "street": 10, + "city": 11, + "zip": 12, + "Notes": 13, + "note": 14, + "first_name": 15, + "last_name": 16, + "nick_name": 17, + }, + }, + enc: NewEncoder[order](), + dec: NewDecoder[order](), + } + + // reflect.DeepEqual cannot compare function pointers + expected.enc.options.EncodeTime = nil + expected.dec.options.DecodeTime.N = nil + expected.dec.options.DecodeTime.S = nil + actual.enc.options.EncodeTime = nil + actual.dec.options.DecodeTime.N = nil + actual.dec.options.DecodeTime.S = nil + + // stuff must be properly sorted when comparing, otherwise reflect.DeepEqual fails + slices.SortStableFunc(expected.globalSecondaryIndexes, gsiSortFunc) + slices.SortStableFunc(actual.globalSecondaryIndexes, gsiSortFunc) + + for idx := range expected.globalSecondaryIndexes { + slices.SortStableFunc(expected.globalSecondaryIndexes[idx].KeySchema, ksSortFunc) + slices.SortStableFunc(actual.globalSecondaryIndexes[idx].KeySchema, ksSortFunc) + } + + if diff := cmpDiff(expected, actual); len(diff) > 0 { + t.Fatalf("unexpected schema diff: %s", diff) + } +} + +func TestSchemaEncodeDecode(t *testing.T) { + now := time.Now().UTC() + + o := order{ + OrderID: "8488941c-0db5-4ace-a8af-3716f2a883bd", + CreatedAt: 1136239445, + UpdatedAt: now, + CustomerID: "507b0215-8413-4dcb-837b-8eaea3812d51", + TotalAmount: 12.34, + IgnoredField: "ignored", + Version: 0, + VersionString: "0", + CounterUp: 0, + CounterDown: 0, + Metadata: map[string]string{ + "string": "string", + "1": "1", + "1.2": "1.2", + }, + address: address{ + Street: "ba5da75d-9fcc-45bc-b2a3-d0b86c0b5919", + City: "c2b86a82-9623-4007-a1ba-112c013f0719", + Zip: "c5262613-aae4-40be-88f6-788f817dd280", + }, + Notes: []string{ + "50debaa8-313c-4113-b90a-9b722fd56ef6", + "d43fdfbd-91e2-4aa9-a46d-64ccc260eb55", + "d8218acf-13b0-458a-9f9c-eaca10b3c080", + }, + customerNote: "e5dbdcc9-9778-4ef0-90f9-108d1c6f6bf3", + CustomerFirstName: "9bcf2a68-7602-42a5-ac70-fc93b2dc17af", + CustomerLastName: "1e73f306-3362-49da-af74-41e1befff588", + } + + m := map[string]types.AttributeValue{ + "order_id": &types.AttributeValueMemberS{ + Value: "8488941c-0db5-4ace-a8af-3716f2a883bd", + }, + "created_at": &types.AttributeValueMemberN{ + Value: "1136239445", + }, + "updated_at": &types.AttributeValueMemberS{ + Value: now.Format(time.RFC3339Nano), + }, + "customer_id": &types.AttributeValueMemberS{ + Value: "507b0215-8413-4dcb-837b-8eaea3812d51", + }, + "total": &types.AttributeValueMemberN{ + Value: "12.34", + }, + "version": &types.AttributeValueMemberN{ + Value: "0", + }, + "versionString": &types.AttributeValueMemberS{ + Value: "0", + }, + "counter_up": &types.AttributeValueMemberN{ + Value: "0", + }, + "counter_down": &types.AttributeValueMemberN{ + Value: "0", + }, + "metadata": &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "string": &types.AttributeValueMemberS{ + Value: "string", + }, + "1": &types.AttributeValueMemberS{ + Value: "1", + }, + "1.2": &types.AttributeValueMemberS{ + Value: "1.2", + }, + }, + }, + "street": &types.AttributeValueMemberS{ + Value: "ba5da75d-9fcc-45bc-b2a3-d0b86c0b5919", + }, + "city": &types.AttributeValueMemberS{ + Value: "c2b86a82-9623-4007-a1ba-112c013f0719", + }, + "zip": &types.AttributeValueMemberS{ + Value: "c5262613-aae4-40be-88f6-788f817dd280", + }, + "Notes": &types.AttributeValueMemberSS{ + Value: []string{ + "50debaa8-313c-4113-b90a-9b722fd56ef6", + "d43fdfbd-91e2-4aa9-a46d-64ccc260eb55", + "d8218acf-13b0-458a-9f9c-eaca10b3c080", + }, + }, + "note": &types.AttributeValueMemberS{ + Value: "e5dbdcc9-9778-4ef0-90f9-108d1c6f6bf3", + }, + "first_name": &types.AttributeValueMemberS{ + Value: "9bcf2a68-7602-42a5-ac70-fc93b2dc17af", + }, + "last_name": &types.AttributeValueMemberS{ + Value: "1e73f306-3362-49da-af74-41e1befff588", + }, + "nick_name": &types.AttributeValueMemberNULL{ + Value: true, + }, + } + + s, _ := NewSchema[order]() + + // new map vs old map + nm, err := s.Encode(&o) + if err != nil { + t.Fatalf("unexpected error for Encode(): %v", err) + } + + if diff := cmpDiff(m, nm); len(diff) > 0 { + e := json.NewEncoder(os.Stdout) + e.SetIndent("", " ") + _ = e.Encode(m) + _ = e.Encode(nm) + t.Fatalf("new map vs old map: keys have different values: %v", diff) + } + + no, err := s.Decode(m) + if err != nil { + t.Fatalf("unexpected error for Decode(): %v", err) + } + if no == nil { + t.Fatalf("unexpected empty object from Decode()") + } + + // ignored CachedFields will not be populated in the new object + o.IgnoredField = "" + // o is a pointer so we need to compare with &o + if diff := cmpDiff(&o, no); len(diff) != 0 { + t.Fatalf("returned object is different: %s", diff) + } +} + +func TestSchemaTableName(t *testing.T) { + cases := []struct { + input *string + expected *string + skipWithTable bool + }{ + { + input: nil, + expected: pointer("order"), + skipWithTable: true, + }, + { + input: nil, + expected: nil, + }, + { + input: pointer(""), + expected: pointer(""), + }, + { + input: pointer("oRdEr"), + expected: pointer("oRdEr"), + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + s, _ := NewSchema[order]() + if !c.skipWithTable { + s.WithTableName(c.input) + } + + if diff := cmpDiff(c.expected, s.TableName()); len(diff) > 0 { + t.Fatalf("unexpected table name diff: %v", diff) + } + }) + } +} + +func TestSchemaCreateTableInput(t *testing.T) { + s, err := NewSchema[order]() + if err != nil { + t.Fatalf("error building schema: %v", err) + } + + expected := &dynamodb.CreateTableInput{ + TableName: pointer("order"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("created_at"), KeyType: types.KeyTypeRange}, + }, + AttributeDefinitions: []types.AttributeDefinition{ + {AttributeName: pointer("order_id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("created_at"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("customer_id"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("total"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("version"), AttributeType: types.ScalarAttributeTypeN}, + {AttributeName: pointer("zip"), AttributeType: types.ScalarAttributeTypeS}, + {AttributeName: pointer("note"), AttributeType: types.ScalarAttributeTypeS}, + }, + LocalSecondaryIndexes: []types.LocalSecondaryIndex{ + { + IndexName: pointer("OrderVersionIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("version"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + }, + GlobalSecondaryIndexes: []types.GlobalSecondaryIndex{ + { + IndexName: pointer("CustomerIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("customer_id"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("created_at"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("TotalAmountIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("total"), KeyType: types.KeyTypeHash}, + {AttributeName: pointer("order_id"), KeyType: types.KeyTypeRange}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("RegionIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("zip"), KeyType: types.KeyTypeHash}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + { + IndexName: pointer("NoteIndex"), + KeySchema: []types.KeySchemaElement{ + {AttributeName: pointer("note"), KeyType: types.KeyTypeHash}, + }, + Projection: &types.Projection{ + NonKeyAttributes: nil, + ProjectionType: types.ProjectionTypeAll, + }, + }, + }, + BillingMode: types.BillingModePayPerRequest, + } + + actual, err := s.createTableInput() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + slices.SortStableFunc(expected.GlobalSecondaryIndexes, gsiSortFunc) + slices.SortStableFunc(actual.GlobalSecondaryIndexes, gsiSortFunc) + + for idx := range expected.GlobalSecondaryIndexes { + slices.SortStableFunc(expected.GlobalSecondaryIndexes[idx].KeySchema, ksSortFunc) + slices.SortStableFunc(actual.GlobalSecondaryIndexes[idx].KeySchema, ksSortFunc) + } + + if diff := cmpDiff(expected, actual); len(diff) > 0 { + t.Fatalf("unexpected diff: %v", diff) + } +} + +func TestSchemaDescribeTableInput(t *testing.T) { + s, _ := NewSchema[order]() + + expected := &dynamodb.DescribeTableInput{ + TableName: pointer("order"), + } + + actual, err := s.describeTableInput() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if diff := cmpDiff(expected, actual); len(diff) > 0 { + t.Fatalf("unexpected diff: %v", diff) + } +} + +func TestSchemaDeleteTableInput(t *testing.T) { + s, _ := NewSchema[order]() + + expected := &dynamodb.DeleteTableInput{ + TableName: pointer("order"), + } + + actual, err := s.deleteTableInput() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if diff := cmpDiff(expected, actual); len(diff) > 0 { + t.Fatalf("unexpected diff: %v", diff) + } +} + +func gsiSortFunc(a, b types.GlobalSecondaryIndex) int { + switch { + case *a.IndexName > *b.IndexName: + return 1 + case *a.IndexName < *b.IndexName: + return -1 + default: + return 0 + } +} diff --git a/feature/dynamodb/entitymanager/shared_test.go b/feature/dynamodb/entitymanager/shared_test.go new file mode 100644 index 00000000000..3db27fa93cd --- /dev/null +++ b/feature/dynamodb/entitymanager/shared_test.go @@ -0,0 +1,426 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type testTextMarshaler struct { + Foo string +} + +func (t *testTextMarshaler) UnmarshalText(b []byte) error { + if !strings.HasPrefix(string(b), "Foo:") { + return fmt.Errorf(`missing "Foo:" prefix`) + } + + t.Foo = string(b)[len("Foo:"):] + return nil +} + +func (t testTextMarshaler) MarshalText() ([]byte, error) { + return []byte("Foo:" + t.Foo), nil +} + +type testBinarySetStruct struct { + Binarys [][]byte `dynamodbav:",binaryset"` +} +type testNumberSetStruct struct { + Numbers []int `dynamodbav:",numberset"` +} +type testStringSetStruct struct { + Strings []string `dynamodbav:",stringset"` +} + +type testIntAsStringStruct struct { + Value int `dynamodbav:",string"` +} + +type testOmitEmptyStruct struct { + Value string `dynamodbav:",omitempty"` + Value2 *string `dynamodbav:",omitempty"` + Value3 int +} + +type testAliasedString string +type testAliasedStringSlice []string +type testAliasedInt int +type testAliasedIntSlice []int +type testAliasedMap map[string]int +type testAliasedSlice []string +type testAliasedByteSlice []byte +type testAliasedBool bool +type testAliasedBoolSlice []bool + +type testAliasedStruct struct { + Value testAliasedString + Value2 testAliasedInt + Value3 testAliasedMap + Value4 testAliasedSlice + + Value5 testAliasedByteSlice + Value6 []testAliasedInt + Value7 []testAliasedString + + Value8 []testAliasedByteSlice `dynamodbav:",binaryset"` + Value9 []testAliasedInt `dynamodbav:",numberset"` + Value10 []testAliasedString `dynamodbav:",stringset"` + + Value11 testAliasedIntSlice + Value12 testAliasedStringSlice + + Value13 testAliasedBool + Value14 testAliasedBoolSlice + + Value15 map[testAliasedString]string +} + +type testNamedPointer *int + +var testDate, _ = time.Parse(time.RFC3339, "2016-05-03T17:06:26.209072Z") + +var sharedTestCases = map[string]struct { + in types.AttributeValue + actual, expected interface{} + err error +}{ + "binary slice": { + in: &types.AttributeValueMemberB{Value: []byte{48, 49}}, + actual: &[]byte{}, + expected: []byte{48, 49}, + }, + "Binary slice oversized": { + in: &types.AttributeValueMemberB{Value: []byte{48, 49}}, + actual: func() *[]byte { + v := make([]byte, 0, 10) + return &v + }(), + expected: []byte{48, 49}, + }, + "binary slice pointer": { + in: &types.AttributeValueMemberB{Value: []byte{48, 49}}, + actual: func() **[]byte { + v := make([]byte, 0, 10) + v2 := &v + return &v2 + }(), + expected: []byte{48, 49}, + }, + "bool": { + in: &types.AttributeValueMemberBOOL{Value: true}, + actual: new(bool), + expected: true, + }, + "list": { + in: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "123"}, + }}, + actual: &[]int{}, + expected: []int{123}, + }, + "map, interface": { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberN{Value: "123"}, + }}, + actual: &map[string]int{}, + expected: map[string]int{"abc": 123}, + }, + "map, struct": { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Abc": &types.AttributeValueMemberN{Value: "123"}, + }}, + actual: &struct{ Abc int }{}, + expected: struct{ Abc int }{Abc: 123}, + }, + "map, struct with tags": { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberN{Value: "123"}, + }}, + actual: &struct { + Abc int `json:"abc" dynamodbav:"abc"` + }{}, + expected: struct { + Abc int `json:"abc" dynamodbav:"abc"` + }{Abc: 123}, + }, + "number, int": { + in: &types.AttributeValueMemberN{Value: "123"}, + actual: new(int), + expected: 123, + }, + "number, Float": { + in: &types.AttributeValueMemberN{Value: "123.1"}, + actual: new(float64), + expected: float64(123.1), + }, + "null pointer": { + in: &types.AttributeValueMemberNULL{Value: true}, + actual: new(*string), + expected: nil, + }, + "string": { + in: &types.AttributeValueMemberS{Value: "abc"}, + actual: new(string), + expected: "abc", + }, + "empty string": { + in: &types.AttributeValueMemberS{Value: ""}, + actual: new(string), + expected: "", + }, + "binary Set": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Binarys": &types.AttributeValueMemberBS{Value: [][]byte{{48, 49}, {50, 51}}}, + }, + }, + actual: &testBinarySetStruct{}, + expected: testBinarySetStruct{Binarys: [][]byte{{48, 49}, {50, 51}}}, + }, + "number Set": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Numbers": &types.AttributeValueMemberNS{Value: []string{"123", "321"}}, + }, + }, + actual: &testNumberSetStruct{}, + expected: testNumberSetStruct{Numbers: []int{123, 321}}, + }, + "string Set": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Strings": &types.AttributeValueMemberSS{Value: []string{"abc", "efg"}}, + }, + }, + actual: &testStringSetStruct{}, + expected: testStringSetStruct{Strings: []string{"abc", "efg"}}, + }, + "int value as string": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Value": &types.AttributeValueMemberS{Value: "123"}, + }, + }, + actual: &testIntAsStringStruct{}, + expected: testIntAsStringStruct{Value: 123}, + }, + "omitempty": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Value3": &types.AttributeValueMemberN{Value: "0"}, + }, + }, + actual: &testOmitEmptyStruct{}, + expected: testOmitEmptyStruct{Value: "", Value2: nil, Value3: 0}, + }, + "aliased type": { + in: &types.AttributeValueMemberM{ + Value: map[string]types.AttributeValue{ + "Value": &types.AttributeValueMemberS{Value: "123"}, + "Value2": &types.AttributeValueMemberN{Value: "123"}, + "Value3": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "Key": &types.AttributeValueMemberN{Value: "321"}, + }}, + "Value4": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "1"}, + &types.AttributeValueMemberS{Value: "2"}, + &types.AttributeValueMemberS{Value: "3"}, + }}, + "Value5": &types.AttributeValueMemberB{Value: []byte{0, 1, 2}}, + "Value6": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, + }}, + "Value7": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "1"}, + &types.AttributeValueMemberS{Value: "2"}, + &types.AttributeValueMemberS{Value: "3"}, + }}, + "Value8": &types.AttributeValueMemberBS{Value: [][]byte{ + {0, 1, 2}, {3, 4, 5}, + }}, + "Value9": &types.AttributeValueMemberNS{Value: []string{ + "1", + "2", + "3", + }}, + "Value10": &types.AttributeValueMemberSS{Value: []string{ + "1", + "2", + "3", + }}, + "Value11": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, + }}, + "Value12": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "1"}, + &types.AttributeValueMemberS{Value: "2"}, + &types.AttributeValueMemberS{Value: "3"}, + }}, + "Value13": &types.AttributeValueMemberBOOL{Value: true}, + "Value14": &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberBOOL{Value: true}, + &types.AttributeValueMemberBOOL{Value: false}, + &types.AttributeValueMemberBOOL{Value: true}, + }}, + "Value15": &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "TestKey": &types.AttributeValueMemberS{Value: "TestElement"}, + }}, + }, + }, + actual: &testAliasedStruct{}, + expected: testAliasedStruct{ + Value: "123", Value2: 123, + Value3: testAliasedMap{ + "Key": 321, + }, + Value4: testAliasedSlice{"1", "2", "3"}, + Value5: testAliasedByteSlice{0, 1, 2}, + Value6: []testAliasedInt{1, 2, 3}, + Value7: []testAliasedString{"1", "2", "3"}, + Value8: []testAliasedByteSlice{ + {0, 1, 2}, + {3, 4, 5}, + }, + Value9: []testAliasedInt{1, 2, 3}, + Value10: []testAliasedString{"1", "2", "3"}, + Value11: testAliasedIntSlice{1, 2, 3}, + Value12: testAliasedStringSlice{"1", "2", "3"}, + Value13: true, + Value14: testAliasedBoolSlice{true, false, true}, + Value15: map[testAliasedString]string{"TestKey": "TestElement"}, + }, + }, + "number named pointer": { + in: &types.AttributeValueMemberN{Value: "123"}, + actual: new(testNamedPointer), + expected: testNamedPointer(aws.Int(123)), + }, + "time.Time": { + in: &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + actual: new(time.Time), + expected: testDate, + }, + "time.Time List": { + in: &types.AttributeValueMemberL{Value: []types.AttributeValue{ + &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + &types.AttributeValueMemberS{Value: "2016-05-04T17:06:26.209072Z"}, + }}, + actual: new([]time.Time), + expected: []time.Time{testDate, testDate.Add(24 * time.Hour)}, + }, + "time.Time struct": { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + }}, + actual: &struct { + Abc time.Time `json:"abc" dynamodbav:"abc"` + }{}, + expected: struct { + Abc time.Time `json:"abc" dynamodbav:"abc"` + }{Abc: testDate}, + }, + "time.Time pointer struct": { + in: &types.AttributeValueMemberM{Value: map[string]types.AttributeValue{ + "abc": &types.AttributeValueMemberS{Value: "2016-05-03T17:06:26.209072Z"}, + }}, + actual: &struct { + Abc *time.Time `json:"abc" dynamodbav:"abc"` + }{}, + expected: struct { + Abc *time.Time `json:"abc" dynamodbav:"abc"` + }{Abc: &testDate}, + }, +} + +var sharedListTestCases = map[string]struct { + in []types.AttributeValue + actual, expected interface{} + err error +}{ + "union members": { + in: []types.AttributeValue{ + &types.AttributeValueMemberB{Value: []byte{48, 49}}, + &types.AttributeValueMemberBOOL{Value: true}, + &types.AttributeValueMemberN{Value: "123"}, + &types.AttributeValueMemberS{Value: "123"}, + }, + actual: func() *[]interface{} { + v := []interface{}{} + return &v + }(), + expected: []interface{}{[]byte{48, 49}, true, 123., "123"}, + }, + "numbers": { + in: []types.AttributeValue{ + &types.AttributeValueMemberN{Value: "1"}, + &types.AttributeValueMemberN{Value: "2"}, + &types.AttributeValueMemberN{Value: "3"}, + }, + actual: &[]interface{}{}, + expected: []interface{}{1., 2., 3.}, + }, +} + +var sharedMapTestCases = map[string]struct { + in map[string]types.AttributeValue + actual, expected interface{} + err error +}{ + "union members": { + in: map[string]types.AttributeValue{ + "B": &types.AttributeValueMemberB{Value: []byte{48, 49}}, + "BOOL": &types.AttributeValueMemberBOOL{Value: true}, + "N": &types.AttributeValueMemberN{Value: "123"}, + "S": &types.AttributeValueMemberS{Value: "123"}, + }, + actual: &map[string]interface{}{}, + expected: map[string]interface{}{ + "B": []byte{48, 49}, "BOOL": true, + "N": 123., "S": "123", + }, + }, +} + +func assertConvertTest(t *testing.T, actual, expected interface{}, err, expectedErr error) { + t.Helper() + + if expectedErr != nil { + if err != nil { + if e, a := expectedErr, err; !strings.Contains(a.Error(), e.Error()) { + t.Errorf("expect %v, got %v", e, a) + } + } else { + t.Fatalf("expected error, %v", expectedErr) + } + } else if err != nil { + t.Fatalf("expect no error, got %v", err) + } else { + if diff := cmpDiff(ptrToValue(expected), ptrToValue(actual)); len(diff) != 0 { + t.Errorf("expect match\n%s", diff) + } + } +} + +func ptrToValue(in interface{}) interface{} { + v := reflect.ValueOf(in) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if !v.IsValid() { + return nil + } + if v.Kind() == reflect.Ptr { + return ptrToValue(v.Interface()) + } + return v.Interface() +} diff --git a/feature/dynamodb/entitymanager/structs_test.go b/feature/dynamodb/entitymanager/structs_test.go new file mode 100644 index 00000000000..de51546e8fd --- /dev/null +++ b/feature/dynamodb/entitymanager/structs_test.go @@ -0,0 +1,42 @@ +package entitymanager + +import ( + "strings" + "time" +) + +type order struct { + OrderID string `dynamodbav:"order_id,partition,autogenerated|key" dynamodbindex:"TotalAmountIndex,global,sort"` + CreatedAt int64 `dynamodbav:"created_at,sort,autogenerated|timestamp" dynamodbindex:"CustomerIndex,sort"` + UpdatedAt time.Time `dynamodbav:"updated_at,autogenerated|timestamp|always"` + CustomerID string `dynamodbav:"customer_id" dynamodbindex:"CustomerIndex,global,partition"` + TotalAmount float64 `dynamodbav:"total" dynamodbindex:"TotalAmountIndex,global,partition"` + IgnoredField string `dynamodbav:"-"` + Version int64 `dynamodbav:"version,version" dynamodbindex:"OrderVersionIndex,local,sort"` + VersionString string `dynamodbav:"versionString,version"` + CounterUp int64 `dynamodbav:"counter_up,atomiccounter|start=0|delta=5"` + CounterDown int64 `dynamodbav:"counter_down,atomiccounter|start=0|delta=-5"` + Metadata map[string]string `dynamodbav:"metadata"` + address + Notes []string `dynamodbav:",preserveempty,stringset"` + customerNote string `dynamodbav:"note" dynamodbgetter:"Note" dynamodbsetter:"SetNote" dynamodbindex:"NoteIndex,global,partition"` + CustomerFirstName string `dynamodbav:"first_name"` + CustomerLastName string `dynamodbav:"last_name"` + CustomerNickName *string `dynamodbav:"nick_name"` +} + +// Getter method for customerNote +func (o *order) Note() string { + return o.customerNote +} + +// Setter method for customerNote +func (o *order) SetNote(note string) { + o.customerNote = strings.TrimSpace(note) +} + +type address struct { + Street string `dynamodbav:"street"` + City string `dynamodbav:"city"` + Zip string `dynamodbav:"zip" dynamodbindex:"RegionIndex,global,partition"` +} diff --git a/feature/dynamodb/entitymanager/table.go b/feature/dynamodb/entitymanager/table.go new file mode 100644 index 00000000000..8204be4fa3e --- /dev/null +++ b/feature/dynamodb/entitymanager/table.go @@ -0,0 +1,151 @@ +package entitymanager + +import ( + "fmt" + "reflect" + + awsmiddleware "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" +) + +// TableOptions provides configuration options for a DynamoDB Table. +// +// T is the type of the items stored in the table. +type TableOptions[T any] struct { + // Schema defines the schema for the table, including attribute mapping and validation. + Schema *Schema[T] + + // DynamoDBOptions is a list of functions to customize the underlying DynamoDB client options. + DynamoDBOptions []func(*dynamodb.Options) + + // ExtensionRegistry holds the registry of extensions to be used with the table. + ExtensionRegistry *ExtensionRegistry[T] + + // MaxConsecutiveErrors sets the maximum number of consecutive errors allowed during batch, query, or scan operations. + // If this threshold is exceeded, the operation will stop and return. + // If set to 0, the default value of DefaultMaxConsecutiveErrors will be used. + MaxConsecutiveErrors uint +} + +// DefaultMaxConsecutiveErrors is the fallback value used for MaxConsecutiveErrors when it is set to 0. +// A value of 1 means the operation will stop after the first error. +const DefaultMaxConsecutiveErrors uint = 1 + +// Table represents a strongly-typed DynamoDB table for items of type T. +// +// It provides methods for interacting with DynamoDB using the provided client and options. +type Table[T any] struct { + // client is the DynamoDB client used to perform operations on the table. + client Client + + // options holds the configuration options for the table. + options TableOptions[T] +} + +// NewTable creates a new Table for items of type T using the provided client and configuration functions. +// +// The configuration functions can be used to customize the TableOptions before the table is created. +// Returns an error if T is not a struct type or if required options cannot be resolved. +func NewTable[T any](client Client, fns ...func(options *TableOptions[T])) (*Table[T], error) { + if reflect.TypeFor[T]().Kind() != reflect.Struct { + return nil, fmt.Errorf("NewClient() can only be created from structs, %T given", *new(T)) + } + + opts := TableOptions[T]{} + + if c, ok := client.(*dynamodb.Client); ok { + client = dynamodb.New(c.Options(), func(o *dynamodb.Options) { + o.APIOptions = append(o.APIOptions, awsmiddleware.AddUserAgentKeyValue(UserAgentPart, EntityManagerVersion)) + }) + } + + for _, fn := range fns { + fn(&opts) + } + + defaultResolvers := []resolverFn[T]{ + resolveDefaultSchema[T], + resolveDefaultExtensionRegistry[T], + resolveDefaultMaxConsecutiveErrors[T], + } + + for _, fn := range defaultResolvers { + if err := fn(&opts); err != nil { + return nil, err + } + } + + return &Table[T]{ + client: client, + options: opts, + }, nil +} + +// WithSchema returns a configuration function that sets the Schema for TableOptions. +// +// Use this to specify a custom schema when creating a Table. +func WithSchema[T any](schema *Schema[T]) func(options *TableOptions[T]) { + return func(options *TableOptions[T]) { + options.Schema = schema + } +} + +// WithExtensionRegistry returns a configuration function that sets the ExtensionRegistry for TableOptions. +// +// Use this to specify a custom extension registry when creating a Table. +func WithExtensionRegistry[T any](registry *ExtensionRegistry[T]) func(options *TableOptions[T]) { + return func(options *TableOptions[T]) { + options.ExtensionRegistry = registry + } +} + +// WithMaxConsecutiveErrors returns a configuration function that sets the MaxConsecutiveErrors option for TableOptions. +// +// Use this to specify the maximum number of consecutive errors allowed during batch, query, or scan operations. +// A value of 0 means no limit is enforced. +// WithMaxConsecutiveErrors returns a configuration function that sets the MaxConsecutiveErrors option for TableOptions. +// +// Use this to specify the maximum number of consecutive errors allowed during batch, query, or scan operations. +// If set to 0, the default value of DefaultMaxConsecutiveErrors will be used. +func WithMaxConsecutiveErrors[T any](maxConsecutiveErrors uint) func(options *TableOptions[T]) { + return func(options *TableOptions[T]) { + options.MaxConsecutiveErrors = maxConsecutiveErrors + } +} + +// resolverFn defines a function type for resolving or setting default options on TableOptions. +type resolverFn[T any] func(opts *TableOptions[T]) error + +// resolveDefaultSchema sets a default schema on TableOptions if none is provided. +// +// Returns an error if the schema cannot be created. +func resolveDefaultSchema[T any](opts *TableOptions[T]) error { + if opts.Schema == nil { + var err error + opts.Schema, err = NewSchema[T]() + if err != nil { + return err + } + } + + return nil +} + +// resolveDefaultExtensionRegistry sets a default extension registry on TableOptions if none is provided. +func resolveDefaultExtensionRegistry[T any](opts *TableOptions[T]) error { + if opts.ExtensionRegistry == nil { + opts.ExtensionRegistry = DefaultExtensionRegistry[T]() + } + + return nil +} + +// resolveDefaultMaxConsecutiveErrors sets MaxConsecutiveErrors to DefaultMaxConsecutiveErrors +// if it is not explicitly set (i.e., if the value is 0). +// This ensures a sensible default for error handling in batch, query, or scan operations. +func resolveDefaultMaxConsecutiveErrors[T any](opts *TableOptions[T]) error { + if opts.MaxConsecutiveErrors == 0 { + opts.MaxConsecutiveErrors = DefaultMaxConsecutiveErrors + } + return nil +} diff --git a/feature/dynamodb/entitymanager/table_extension.go b/feature/dynamodb/entitymanager/table_extension.go new file mode 100644 index 00000000000..85066dce242 --- /dev/null +++ b/feature/dynamodb/entitymanager/table_extension.go @@ -0,0 +1,174 @@ +package entitymanager + +import ( + "context" + "fmt" + "reflect" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" +) + +func (t *Table[T]) createExtensionContext() context.Context { + ctx := context.Background() + if t.options.Schema != nil { + ctx = context.WithValue(ctx, TableSchemaKey{}, t.options.Schema) + ctx = context.WithValue(ctx, CachedFieldsKey{}, t.options.Schema.cachedFields) + } else { + ctx = context.WithValue(ctx, TableSchemaKey{}, (*Schema[T])(nil)) + ctx = context.WithValue(ctx, CachedFieldsKey{}, (*CachedFields)(nil)) + } + + return ctx +} + +func (t *Table[T]) createUpdateExpression(v *T) (expression.Expression, error) { + empty := expression.Expression{} + + if t.options.Schema == nil || t.options.Schema.cachedFields == nil { + return empty, fmt.Errorf("empty schema or schema cache fields for Table[%T]", *new(T)) + } + + ctx := t.createExtensionContext() + + var conditionBuilder *expression.ConditionBuilder + var filterBuilder *expression.ConditionBuilder + var keyConditionBuilder *expression.KeyConditionBuilder + var projectionBuilder *expression.ProjectionBuilder + updateBuilder := &expression.UpdateBuilder{} + + r := reflect.ValueOf(v) + + for _, f := range t.options.Schema.cachedFields.All() { + // skip items that will be autogenerated or pk or sk + if f.AutoGenerated || f.Version || f.AtomicCounter || f.Partition || f.Sort { + continue + } + + var cv reflect.Value + if f.Tag.Getter != "" { + cv = r.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + } else { + cv = r.Elem().FieldByIndex(f.Index) + } + + *updateBuilder = updateBuilder.Set(expression.Name(f.Name), expression.Value(cv.Interface())) + } + + for _, e := range t.options.ExtensionRegistry.beforeWriters { + if b, ok := e.(ConditionExpressionBuilder[T]); ok { + if err := b.BuildCondition(ctx, v, &conditionBuilder); err != nil { + return empty, fmt.Errorf("error during ConditionExpressionBuilder[%T]: %v", e, err) + } + } + if b, ok := e.(FilterExpressionBuilder[T]); ok { + if err := b.BuildFilter(ctx, v, &filterBuilder); err != nil { + return empty, fmt.Errorf("error during FilterExpressionBuilder[%T]: %v", e, err) + } + } + if b, ok := e.(KeyConditionBuilder[T]); ok { + if err := b.BuildKeyCondition(ctx, v, &keyConditionBuilder); err != nil { + return empty, fmt.Errorf("error during KeyConditionBuilder[%T]: %v", e, err) + } + } + if b, ok := e.(ProjectionExpressionBuilder[T]); ok { + if err := b.BuildProjection(ctx, v, &projectionBuilder); err != nil { + return empty, fmt.Errorf("error during ProjectionExpressionBuilder[%T]: %v", e, err) + } + } + if b, ok := e.(UpdateExpressionBuilder[T]); ok { + if err := b.BuildUpdate(ctx, v, &updateBuilder); err != nil { + return empty, fmt.Errorf("error during UpdateExpressionBuilder[%T]: %v", e, err) + } + } + } + + builder := expression.NewBuilder() + if conditionBuilder != nil { + builder = builder.WithCondition(*conditionBuilder) + } + if filterBuilder != nil { + builder = builder.WithFilter(*filterBuilder) + } + if keyConditionBuilder != nil { + builder = builder.WithKeyCondition(*keyConditionBuilder) + } + if projectionBuilder != nil { + builder = builder.WithProjection(*projectionBuilder) + } + if updateBuilder != nil { + builder = builder.WithUpdate(*updateBuilder) + } + + return builder.Build() +} + +func (t *Table[T]) applyBeforeReadExtensions(v *T) error { + if t.options.ExtensionRegistry == nil { + return nil + } + + ctx := t.createExtensionContext() + + for _, br := range t.options.ExtensionRegistry.beforeReaders { + if err := br.BeforeRead(ctx, v); err != nil { + return fmt.Errorf("error during applyBeforeReadExtensions %T: %v", br, err) + } + } + + return nil +} + +func (t *Table[T]) applyAfterReadExtensions(v *T) error { + if t.options.ExtensionRegistry == nil { + return nil + } + + ctx := t.createExtensionContext() + + for _, ar := range t.options.ExtensionRegistry.afterReaders { + if err := ar.AfterRead(ctx, v); err != nil { + return fmt.Errorf("error during applyAfterReadExtensions %T: %v", ar, err) + } + } + + return nil +} + +func (t *Table[T]) applyBeforeWriteExtensions(v *T) error { + if t.options.ExtensionRegistry == nil { + return nil + } + + ctx := t.createExtensionContext() + + for _, bw := range t.options.ExtensionRegistry.beforeWriters { + if err := bw.BeforeWrite(ctx, v); err != nil { + return fmt.Errorf("error during applyBeforeWriteExtensions %T: %v", bw, err) + } + } + + return nil +} + +func (t *Table[T]) applyAfterWriteExtensions(v *T) error { + if t.options.ExtensionRegistry == nil { + return nil + } + + ctx := t.createExtensionContext() + + for _, aw := range t.options.ExtensionRegistry.afterWriters { + if err := aw.AfterWrite(ctx, v); err != nil { + return fmt.Errorf("error during applyBeforeWriteExtensions %T: %v", aw, err) + } + } + + return nil +} + +// @TODO: implement when adding Scan() and Query() +// func (t *Table[T]) applyBeforeScannersExtensions(v *T) error { return nil } +// func (t *Table[T]) applyAfterScannersExtensions(v *T) error { return nil } +// func (t *Table[T]) applyBeforeQueriersExtensions(v *T) error { return nil } +// func (t *Table[T]) applyAfterQueriersExtensions(v []T) error { return nil } diff --git a/feature/dynamodb/entitymanager/table_extension_test.go b/feature/dynamodb/entitymanager/table_extension_test.go new file mode 100644 index 00000000000..a3892995548 --- /dev/null +++ b/feature/dynamodb/entitymanager/table_extension_test.go @@ -0,0 +1,300 @@ +package entitymanager + +import ( + "context" + "strconv" + "testing" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" +) + +func TestTableCreateExtensionContext(t *testing.T) { + type extensionContextCreator interface { + createExtensionContext() context.Context + } + + makeContext := func(args ...any) context.Context { + ctx := context.Background() + + for c := range len(args) / 2 { + ctx = context.WithValue(ctx, args[c*2], args[c*2+1]) + } + + return ctx + } + + keys := []any{ + CachedFieldsKey{}, + TableSchemaKey{}, + } + + cases := []struct { + source extensionContextCreator + expected context.Context + }{ + { + source: &Table[any]{}, + expected: makeContext( + CachedFieldsKey{}, + (*CachedFields)(nil), + TableSchemaKey{}, + (*Schema[any])(nil), + ), + }, + { + source: &Table[any]{ + options: TableOptions[any]{ + Schema: &Schema[any]{ + cachedFields: &CachedFields{ + fields: []Field{ + { + Tag: Tag{}, + Name: "", + NameFromTag: false, + Index: nil, + Type: nil, + }, + }, + fieldsByName: map[string]int{ + "": 0, + }, + }, + }, + }, + }, + expected: makeContext( + CachedFieldsKey{}, + &CachedFields{ + fields: []Field{ + { + Tag: Tag{}, + Name: "", + NameFromTag: false, + Index: nil, + Type: nil, + }, + }, + fieldsByName: map[string]int{ + "": 0, + }, + }, + TableSchemaKey{}, + &Schema[any]{ + cachedFields: &CachedFields{ + fields: []Field{ + { + Tag: Tag{}, + Name: "", + NameFromTag: false, + Index: nil, + Type: nil, + }, + }, + fieldsByName: map[string]int{ + "": 0, + }, + }, + }, + ), + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + actual := c.source.createExtensionContext() + + for _, k := range keys { + if diff := cmpDiff(c.expected.Value(k), actual.Value(k)); len(diff) != 0 { + t.Fatalf("unexpected diff: %s", diff) + } + } + }) + } +} + +type testStruct struct { + beforeWrite int + afterRead int +} + +type testExtension struct{} + +func (t *testExtension) IsExtension() {} + +func (t *testExtension) BeforeRead(ctx context.Context, v *testStruct) error { + return nil +} + +func (t *testExtension) AfterRead(ctx context.Context, v *testStruct) error { + return nil +} + +func (t *testExtension) BeforeWrite(ctx context.Context, v *testStruct) error { + return nil +} + +func (t *testExtension) AfterWrite(ctx context.Context, v *testStruct) error { + return nil +} + +func (t *testExtension) BuildCondition(context.Context, *testStruct, **expression.ConditionBuilder) error { + return nil +} +func (t *testExtension) BuildFilter(context.Context, *testStruct, **expression.ConditionBuilder) error { + return nil +} +func (t *testExtension) BuildKeyCondition(context.Context, *testStruct, **expression.KeyConditionBuilder) error { + return nil +} +func (t *testExtension) BuildProjection(context.Context, *testStruct, **expression.ProjectionBuilder) error { + return nil +} +func (t *testExtension) BuildUpdate(context.Context, *testStruct, **expression.UpdateBuilder) error { + return nil +} + +//func TestApplyBeforeWriteExtensions(t *testing.T) { +// cases := []struct { +// when ExecutionPhase +// extensions []Extension +// input any +// expected any +// error bool +// }{ +// { +// extensions: []Extension{}, +// error: false, +// }, +// } +// +// for i, c := range cases { +// t.Run(strconv.Itoa(i), func(t *testing.T) { +// str := testStruct{} +// sch := &Table[testStruct]{ +// options: TableOptions[testStruct]{ +// ExtensionRegistry: &ExtensionRegistry[testStruct]{ +// beforeWriters: []BeforeWriter[testStruct]{}, +// }, +// }, +// +// //extensions: map[ExecutionPhase][]Extension{ +// // BeforeWrite: c.extensions, +// //}, +// } +// err := sch.applyAfterReadExtensions(&str) +// +// if !c.error && err != nil { +// t.Errorf("unexpected error: %v", err) +// +// return +// } +// +// if c.error && err == nil { +// t.Error("expected error") +// +// return +// } +// +// //if diff := cmpDiff(c.expected, av); len(diff) != 0 { +// // t.Errorf("unexpected diff: %s", diff) +// //} +// _ = c +// }) +// } +// _ = cases +//} + +//func TestSchemaApplyExtension(t *testing.T) { +// if true { +// return +// } +// cases := []struct { +// when ExecutionPhase +// actual map[string]types.AttributeValue +// expected map[string]types.AttributeValue +// error bool +// }{ +// { +// when: BeforeWrite, +// actual: map[string]types.AttributeValue{}, +// expected: map[string]types.AttributeValue{ +// "id": &types.AttributeValueMemberS{ +// Value: "", +// }, +// }, +// }, +// } +// +// buff := bytes.Buffer{} +// cryptorand.Reader = io.TeeReader(cryptorand.Reader, &buff) +// +// s, _ := NewSchema[order]() +// +// for i, c := range cases { +// t.Run(strconv.Itoa(i), func(t *testing.T) { +// t.Logf("buffer too big: %d", len(buff.Bytes())) +// buff.Reset() +// t.Logf("buffer too big: %d", len(buff.Bytes())) +// +// actual, _ := s.Decode(c.actual) +// +// var err error +// switch c.when { +// case BeforeWrite: +// err = s.applyBeforeWriteExtensions(actual) +// case AfterRead: +// err = s.applyAfterReadExtensions(actual) +// //case BeforeQuery: +// // err = s.apply(actual) +// //case BeforeScan: +// // err = s.applyBeforeWriteExtensions(actual) +// default: +// t.Fatalf("i don't know how to handle: %s", c.when) +// } +// //err := s.applyExtensions(BeforeWrite, actual) +// +// t.Logf("buffer too big: %d", len(buff.Bytes())) +// +// fmt.Printf("%#+v\n", actual) +// +// b := buff.Bytes()[:] +// if len(b) != 16 { +// t.Fatalf("buffer too big: %d", len(b)) +// } +// +// //b[6] = (b[6] & 0x0f) | 0x40 +// //b[8] = (b[8] & 0x3f) | 0x80 +// //c.expected["id"] = &types.AttributeValueMemberS{ +// // Value: fmt.Sprintf( +// // "%x-%x-%x-%x-%x", +// // b[0:4], +// // b[4:6], +// // b[6:8], +// // b[8:10], +// // b[10:16], +// // ), +// //} +// //fmt.Println(c.expected["id"].(*types.AttributeValueMemberS).Value) +// +// if !c.error && err != nil { +// t.Errorf("unexpected error: %v", err) +// +// return +// } +// +// if c.error && err == nil { +// t.Error("expected error") +// +// return +// } +// +// if diff := cmpDiff(c.expected, c.actual); len(diff) != 0 { +// e := json.NewEncoder(os.Stdout) +// e.SetIndent("", " ") +// _ = e.Encode(c.actual) +// _ = e.Encode(c.expected) +// t.Errorf("unexpected diff: %s", diff) +// } +// }) +// } +//} diff --git a/feature/dynamodb/entitymanager/table_item_operations.go b/feature/dynamodb/entitymanager/table_item_operations.go new file mode 100644 index 00000000000..ea1d07b779e --- /dev/null +++ b/feature/dynamodb/entitymanager/table_item_operations.go @@ -0,0 +1,383 @@ +package entitymanager + +import ( + "context" + "fmt" + "iter" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +// GetItem retrieves a single item from the DynamoDB table by its key. +// Returns the decoded item or an error if not found or decoding fails. +func (t *Table[T]) GetItem(ctx context.Context, m Map, optFns ...func(*dynamodb.Options)) (*T, error) { + res, err := t.client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: t.options.Schema.TableName(), + Key: m, + }, optFns...) + if err != nil { + return nil, err + } + + if res == nil || res.Item == nil { + return nil, fmt.Errorf("empty response or item in GetItem() call") + } + + item, err := t.options.Schema.Decode(res.Item) + if err != nil { + return nil, err + } + + err = t.applyAfterReadExtensions(item) + if err != nil { + return nil, err + } + + return item, nil +} + +// GetItemWithProjection retrieves a single item from the DynamoDB table by its key, applying a projection to select specific attributes. +// Returns the decoded item or an error if not found or decoding fails. +func (t *Table[T]) GetItemWithProjection(ctx context.Context, m Map, proj expression.ProjectionBuilder, optFns ...func(*dynamodb.Options)) (*T, error) { + b, err := expression.NewBuilder().WithProjection(proj).Build() + if err != nil { + return nil, err + } + + res, err := t.client.GetItem(ctx, &dynamodb.GetItemInput{ + TableName: t.options.Schema.TableName(), + Key: m, + ExpressionAttributeNames: b.Names(), + ProjectionExpression: b.Projection(), + }, optFns...) + if err != nil { + return nil, err + } + + if res == nil || res.Item == nil { + return nil, fmt.Errorf("empty response or item in GetItemWithProjection() call") + } + + item, err := t.options.Schema.Decode(res.Item) + if err != nil { + return nil, err + } + + err = t.applyAfterReadExtensions(item) + if err != nil { + return nil, err + } + + return item, nil +} + +// PutItem writes the item to the DynamoDB table without checking for collisions. +// Returns the written item or an error if encoding or writing fails. +func (t *Table[T]) PutItem(ctx context.Context, item *T, optFns ...func(*dynamodb.Options)) (*T, error) { + err := t.applyBeforeWriteExtensions(item) + if err != nil { + return nil, err + } + + itemMap, err := t.options.Schema.Encode(item) + if err != nil { + return nil, err + } + + res, err := t.client.PutItem(ctx, &dynamodb.PutItemInput{ + TableName: t.options.Schema.TableName(), + Item: itemMap, + }, optFns...) + if err != nil { + return nil, err + } + if res == nil { + return nil, fmt.Errorf("empty response in PutItem() call") + } + + out, err := t.options.Schema.Decode(itemMap) + if err != nil { + return nil, err + } + + if err := t.applyAfterWriteExtensions(out); err != nil { + return nil, err + } + + return out, nil +} + +// UpdateItem writes the item to the DynamoDB table with additional checks (e.g., version checks). +// Returns the updated item or an error if encoding or updating fails. +func (t *Table[T]) UpdateItem(ctx context.Context, item *T, optFns ...func(*dynamodb.Options)) (*T, error) { + err := t.applyBeforeWriteExtensions(item) + if err != nil { + return nil, err + } + + m, err := t.options.Schema.createKeyMap(item) + if err != nil { + return nil, err + } + + expr, err := t.createUpdateExpression(item) + if err != nil { + return nil, err + } + + res, err := t.client.UpdateItem(ctx, &dynamodb.UpdateItemInput{ + TableName: t.options.Schema.TableName(), + Key: m, + ConditionExpression: expr.Condition(), + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + UpdateExpression: expr.Update(), + ReturnValues: types.ReturnValueAllNew, + ReturnValuesOnConditionCheckFailure: types.ReturnValuesOnConditionCheckFailureAllOld, + }, optFns...) + if err != nil { + return nil, err + } + if res == nil { + return nil, fmt.Errorf("empty response in UpdateItem() call") + } + + out, err := t.options.Schema.Decode(res.Attributes) + if err != nil { + return nil, err + } + + if err := t.applyAfterWriteExtensions(out); err != nil { + return nil, err + } + + return out, nil +} + +// DeleteItem deletes an item from the DynamoDB table by its struct value. +// Returns an error if the key cannot be created or the delete fails. +func (t *Table[T]) DeleteItem(ctx context.Context, item *T, optFns ...func(*dynamodb.Options)) error { + m, err := t.options.Schema.createKeyMap(item) + if err != nil { + return err + } + + _, err = t.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{ + TableName: t.options.Schema.TableName(), + Key: m, + }, optFns...) + + return err +} + +// DeleteItemByKey deletes an item from the DynamoDB table by its key map. +// Returns an error if the delete fails. +func (t *Table[T]) DeleteItemByKey(ctx context.Context, m Map, optFns ...func(*dynamodb.Options)) error { + _, err := t.client.DeleteItem(ctx, &dynamodb.DeleteItemInput{ + TableName: t.options.Schema.TableName(), + Key: m, + }, optFns...) + + return err +} + +// createScanIterator returns an iterator that scans a DynamoDB table or index and yields results as ItemResult[*T]. +// It automatically handles pagination and error thresholds using MaxConsecutiveErrors. +// If the number of consecutive errors reaches the threshold, iteration stops. +func (t Table[T]) createScanIterator(ctx context.Context, indexName *string, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + var consecutiveErrors uint = 0 + var maxConsecutiveErrors = t.options.MaxConsecutiveErrors + if maxConsecutiveErrors == 0 { + maxConsecutiveErrors = DefaultMaxConsecutiveErrors + } + + return func(yield func(ItemResult[*T]) bool) { + var lastEvaluatedKey map[string]types.AttributeValue + + for { + scanInput := &dynamodb.ScanInput{ + TableName: t.options.Schema.TableName(), + IndexName: indexName, + ConsistentRead: aws.Bool(indexName == nil), + ExclusiveStartKey: lastEvaluatedKey, + Select: types.SelectAllAttributes, + FilterExpression: expr.Filter(), + ProjectionExpression: expr.Projection(), + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + } + + res, err := t.client.Scan(ctx, scanInput, optFns...) + if err != nil { + consecutiveErrors++ + + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + if consecutiveErrors >= maxConsecutiveErrors { + return + } + + continue + } + + consecutiveErrors = 0 + + if res != nil && res.Items != nil { + for _, item := range res.Items { + i, err := t.options.Schema.Decode(item) + if err != nil { + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + continue + } + + if err := t.applyAfterReadExtensions(i); err != nil { + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + continue + } + + if !yield(ItemResult[*T]{item: i, table: *t.options.Schema.TableName()}) { + return + } + } + + lastEvaluatedKey = res.LastEvaluatedKey + } else { + lastEvaluatedKey = nil + } + + if lastEvaluatedKey == nil { + return + } + } + } +} + +// ScanIndex scans a DynamoDB index and returns an iterator of results. +// The scan uses the provided index name and expression. +func (t *Table[T]) ScanIndex(ctx context.Context, indexName string, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + return t.createScanIterator(ctx, &indexName, expr, optFns...) +} + +// Scan scans the DynamoDB table and returns an iterator of results. +// The scan uses the provided expression. +func (t *Table[T]) Scan(ctx context.Context, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + return t.createScanIterator(ctx, nil, expr, optFns...) +} + +// createQueryIterator returns an iterator that queries a DynamoDB table or index and yields results as ItemResult[*T]. +// It automatically handles pagination and error thresholds using MaxConsecutiveErrors. +// If the number of consecutive errors reaches the threshold, iteration stops. +func (t *Table[T]) createQueryIterator(ctx context.Context, indexName *string, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + var consecutiveErrors uint = 0 + var maxConsecutiveErrors = t.options.MaxConsecutiveErrors + if maxConsecutiveErrors == 0 { + maxConsecutiveErrors = DefaultMaxConsecutiveErrors + } + + return func(yield func(ItemResult[*T]) bool) { + var lastEvaluatedKey map[string]types.AttributeValue + + for { + res, err := t.client.Query(ctx, &dynamodb.QueryInput{ + TableName: t.options.Schema.TableName(), + IndexName: indexName, + ConsistentRead: aws.Bool(indexName == nil), + ExclusiveStartKey: lastEvaluatedKey, + KeyConditionExpression: expr.KeyCondition(), + ExpressionAttributeNames: expr.Names(), + ExpressionAttributeValues: expr.Values(), + FilterExpression: expr.Filter(), + ProjectionExpression: expr.Projection(), + Select: types.SelectAllAttributes, + }, optFns...) + + if err != nil { + consecutiveErrors++ + + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + if consecutiveErrors >= maxConsecutiveErrors { + return + } + + continue + } + + consecutiveErrors = 0 + + if res == nil { + return + } + + if res != nil && res.Items != nil { + for _, item := range res.Items { + i, err := t.options.Schema.Decode(item) + if err != nil { + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + continue + } + + if err := t.applyAfterReadExtensions(i); err != nil { + if !yield(ItemResult[*T]{err: err, table: *t.options.Schema.TableName()}) { + return + } + + continue + } + + if !yield(ItemResult[*T]{item: i, table: *t.options.Schema.TableName()}) { + return + } + } + + lastEvaluatedKey = res.LastEvaluatedKey + } else { + lastEvaluatedKey = nil + } + + if lastEvaluatedKey == nil { + return + } + } + } +} + +// QueryIndex queries a DynamoDB index and returns an iterator of results. +// The query uses the provided index name and expression. +func (t *Table[T]) QueryIndex(ctx context.Context, indexName string, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + return t.createQueryIterator(ctx, &indexName, expr, optFns...) +} + +// Query queries the DynamoDB table and returns an iterator of results. +// The query uses the provided expression. +func (t *Table[T]) Query(ctx context.Context, expr expression.Expression, optFns ...func(*dynamodb.Options)) iter.Seq[ItemResult[*T]] { + return t.createQueryIterator(ctx, nil, expr, optFns...) +} + +// CreateBatchWriteOperation creates a new BatchWriteOperation for the table. +// Use this to perform batched put and delete operations for the table's items. +func (t *Table[T]) CreateBatchWriteOperation() *BatchWriteOperation[T] { + return NewBatchWriteOperation(t) +} + +// CreateBatchGetOperation creates a new BatchGetOperation for the table. +// Use this to perform batched reads for the table's items. +func (t *Table[T]) CreateBatchGetOperation() *BatchGetOperation[T] { + return NewBatchGetOperation(t) +} diff --git a/feature/dynamodb/entitymanager/table_item_operations_test.go b/feature/dynamodb/entitymanager/table_item_operations_test.go new file mode 100644 index 00000000000..e458eb1d82e --- /dev/null +++ b/feature/dynamodb/entitymanager/table_item_operations_test.go @@ -0,0 +1,532 @@ +package entitymanager + +import ( + "context" + "errors" + "fmt" + "math" + rand2 "math/rand/v2" + "reflect" + "strconv" + "strings" + "testing" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func makeField(name string, t reflect.Type) types.AttributeValue { + k := t.Kind() + switch k { + case reflect.String: + if strings.Contains(name, "version") { + return &types.AttributeValueMemberS{ + Value: fmt.Sprintf("%d", rand2.Int32N(100)), + } + } + + return &types.AttributeValueMemberS{ + Value: strings.Repeat(string(byte(rand2.UintN(93)+33)), rand2.IntN(100)), + } + case reflect.Int64: + return &types.AttributeValueMemberN{ + Value: fmt.Sprintf("%d", rand2.Int64N(math.MaxInt)), + } + case reflect.Float64: + return &types.AttributeValueMemberN{ + Value: fmt.Sprintf("%d.%d", rand2.Int64N(math.MaxInt), rand2.Int64N(math.MaxInt)), + } + case reflect.Map: + m := map[string]types.AttributeValue{} + for c := range 10 { + s := fmt.Sprintf("%d", c) + m[s] = makeField(s, reflect.TypeFor[string]()) + } + return &types.AttributeValueMemberM{ + Value: m, + } + case reflect.Slice, reflect.Array: + l := []types.AttributeValue{} + for range 10 { + l = append(l, makeField(name, t.Elem())) + } + return &types.AttributeValueMemberL{ + Value: l, + } + } + return nil +} + +func makeItem[T any]() map[string]types.AttributeValue { + s, _ := NewSchema[T]() + + out := map[string]types.AttributeValue{} + + for _, f := range s.cachedFields.All() { + out[f.Name] = makeField(f.Name, f.Type) + } + + return out +} + +func assertField[V any](t *testing.T, i map[string]types.AttributeValue, key string, value V) { + var rv V + err := NewDecoder[V]().Decode(i[key], &rv) + if err != nil { + t.Errorf(`unable to decode "%v"`, i[key]) + return + } + if diff := cmpDiff(rv, value); diff != "" { + t.Errorf(`enexpected diff for "%s": %v`, key, diff) + } +} + +func assertItem(t *testing.T, i map[string]types.AttributeValue, o *order) { + if o == nil { + t.Error(`order is nil`) + return + } + + assertField(t, i, "order_id", o.OrderID) + assertField(t, i, "customer_id", o.CustomerID) + //assertField(t, i, "versionString", o.VersionString) + assertField(t, i, "street", o.Street) + assertField(t, i, "city", o.City) + assertField(t, i, "zip", o.Zip) + assertField(t, i, "note", o.customerNote) + assertField(t, i, "first_name", o.CustomerFirstName) + assertField(t, i, "last_name", o.CustomerLastName) + //assertField(t, i, "created_at", o.CreatedAt) + // float fields have garbage :( + //assertField(t, i, "total", o.TotalAmount) + //assertField(t, i, "version", o.Version) + //assertField(t, i, "counter_up", o.CounterUp) + //assertField(t, i, "counter_down", o.CounterDown) +} + +func TestTableCreate(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultCreateTableCall(nil), + withExpectFns(expectTablesCount(1)), + withExpectFns(expectTable("order")), + ), + }, + { + client: newMockClient( + withDefaultCreateTableCall(errors.New("1")), + withExpectFns(expectTablesCount(0)), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, err = table.Create(context.Background()) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableDescribe(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultCreateTableCall(nil), + withDefaultDescribeTableCall(nil), + withExpectFns(expectTablesCount(1)), + withExpectFns(expectTable("order")), + ), + }, + { + client: newMockClient( + withDefaultCreateTableCall(errors.New("1")), + withDefaultDescribeTableCall(errors.New("1")), + withExpectFns(expectTablesCount(0)), + ), + expectedError: true, + }, + { + client: newMockClient( + withDefaultCreateTableCall(nil), + withDefaultDescribeTableCall(errors.New("1")), + withExpectFns(expectTablesCount(1)), + withExpectFns(expectTable("order")), + ), + expectedError: true, + }, + { + client: newMockClient( + withDefaultCreateTableCall(errors.New("1")), + withDefaultDescribeTableCall(nil), + withExpectFns(expectTablesCount(0)), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, _ = table.Create(context.Background()) + + _, err = table.Describe(context.Background()) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableDelete(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultCreateTableCall(nil), + withDefaultDeleteTableCall(nil), + withExpectFns(expectTablesCount(0)), + ), + }, + { + client: newMockClient( + withDefaultCreateTableCall(errors.New("1")), + withDefaultDeleteTableCall(nil), + withExpectFns(expectTablesCount(0)), + ), + expectedError: true, + }, + { + client: newMockClient( + withDefaultCreateTableCall(errors.New("1")), + withDefaultDeleteTableCall(nil), + withExpectFns(expectTablesCount(0)), + ), + expectedError: true, + }, + { + client: newMockClient( + withDefaultCreateTableCall(nil), + withDefaultDeleteTableCall(errors.New("1")), + withExpectFns(expectTablesCount(1)), + withExpectFns(expectTable("order")), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, _ = table.Create(context.Background()) + + _, err = table.Delete(context.Background()) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableGetItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultGetItemCall(nil), + withItem("order", makeItem[order]()), + ), + }, + { + client: newMockClient( + withDefaultGetItemCall(errors.New("1")), + withItem("order", makeItem[order]()), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, err = table.GetItem(context.Background(), Map{}) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTablePutItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultPutItemCall(nil), + withExpectFns(expectItemsCount("order", 1)), + ), + }, + { + client: newMockClient( + withDefaultPutItemCall(errors.New("1")), + withExpectFns(expectItemsCount("order", 0)), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, err = table.PutItem(context.Background(), &order{}) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableUpdateItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withDefaultUpdateItemCall(nil), + withExpectFns(expectItemsCount("order", 1)), + ), + }, + { + client: newMockClient( + withDefaultUpdateItemCall(errors.New("1")), + withExpectFns(expectItemsCount("order", 0)), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + _, err = table.UpdateItem(context.Background(), &order{}) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableDeleteItem(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 2), + withDefaultDeleteItemCall(nil), + withExpectFns(expectItemsCount("order", 1)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 2), + withDefaultDeleteItemCall(errors.New("1")), + withExpectFns(expectItemsCount("order", 2)), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + err = table.DeleteItem(context.Background(), &order{}) + if c.expectedError && err == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + }) + } +} + +func TestTableQuery(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultQueryCall(nil, 9), + withDefaultQueryCall(nil, 8), + withDefaultQueryCall(nil, 7), + withDefaultQueryCall(nil, 6), + withDefaultQueryCall(nil, 0), + withExpectFns(expectItemsCount("order", 2)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultQueryCall(errors.New("1"), 0), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + for res := range table.Query(context.Background(), expression.Expression{}) { + if c.expectedError && res.Error() == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && res.Error() != nil { + t.Fatalf("unexpected error: %v", res.Error()) + } + } + }) + } +} + +func TestTableScan(t *testing.T) { + cases := []struct { + client Client + expectedError bool + }{ + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultScanCall(nil, 9), + withDefaultScanCall(nil, 8), + withDefaultScanCall(nil, 7), + withDefaultScanCall(nil, 6), + withDefaultScanCall(nil, 0), + withExpectFns(expectItemsCount("order", 2)), + ), + }, + { + client: newMockClient( + withItems("order", makeItem[order], 32), + withDefaultScanCall(errors.New("1"), 0), + ), + expectedError: true, + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + defer c.client.(*mockClient).RunExpectations(t) + + table, err := NewTable[order](c.client) + if err != nil { + t.Errorf("unexpcted table error: %v", err) + } + + for res := range table.Scan(context.Background(), expression.Expression{}) { + if c.expectedError && res.Error() == nil { + t.Fatalf("expected error but got none") + } + + if !c.expectedError && res.Error() != nil { + t.Fatalf("unexpected error: %v", res.Error()) + } + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/table_schema_operations.go b/feature/dynamodb/entitymanager/table_schema_operations.go new file mode 100644 index 00000000000..76ca7fdcce7 --- /dev/null +++ b/feature/dynamodb/entitymanager/table_schema_operations.go @@ -0,0 +1,79 @@ +package entitymanager + +import ( + "context" + "errors" + "time" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func (t *Table[T]) Create(ctx context.Context) (*dynamodb.CreateTableOutput, error) { + input, err := t.options.Schema.createTableInput() + if err != nil { + return nil, err + } + + return t.client.CreateTable(ctx, input, t.options.DynamoDBOptions...) +} + +func (t *Table[T]) CreateWithWait(ctx context.Context, maxWaitDur time.Duration) error { + cto, err := t.Create(ctx) + if err != nil { + return err + } + + waiter := dynamodb.NewTableExistsWaiter(t.client) + + return waiter.Wait(ctx, &dynamodb.DescribeTableInput{ + TableName: cto.TableDescription.TableName, + }, maxWaitDur) +} + +func (t *Table[T]) Describe(ctx context.Context) (*dynamodb.DescribeTableOutput, error) { + describe, err := t.options.Schema.describeTableInput() + if err != nil { + return nil, err + } + + return t.client.DescribeTable(ctx, describe, t.options.DynamoDBOptions...) +} + +func (t *Table[T]) Delete(ctx context.Context) (*dynamodb.DeleteTableOutput, error) { + dlt, err := t.options.Schema.deleteTableInput() + if err != nil { + return nil, err + } + + return t.client.DeleteTable(ctx, dlt, t.options.DynamoDBOptions...) +} + +func (t *Table[T]) DeleteWithWait(ctx context.Context, maxWaitDur time.Duration) error { + dlt, err := t.Delete(ctx) + if err != nil { + return err + } + + waiter := dynamodb.NewTableNotExistsWaiter(t.client) + + return waiter.Wait(ctx, &dynamodb.DescribeTableInput{ + TableName: dlt.TableDescription.TableName, + }, maxWaitDur) +} + +func (t *Table[T]) Exists(ctx context.Context) (bool, error) { + _, err := t.Describe(ctx) + + if err != nil { + var notFound *types.ResourceNotFoundException + + if ok := errors.As(err, ¬Found); ok { + return false, nil + } + + return false, err + } + + return true, nil +} diff --git a/feature/dynamodb/entitymanager/table_serde.go b/feature/dynamodb/entitymanager/table_serde.go new file mode 100644 index 00000000000..39026e531ce --- /dev/null +++ b/feature/dynamodb/entitymanager/table_serde.go @@ -0,0 +1,88 @@ +package entitymanager + +import ( + "reflect" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func (s *Schema[T]) Encode(t *T) (map[string]types.AttributeValue, error) { + v := reflect.ValueOf(t) + out := map[string]types.AttributeValue{} + + for _, f := range s.cachedFields.All() { + var fv reflect.Value + var err error + + if f.Tag.Getter != "" { + m := v.MethodByName(f.Tag.Getter) + fv = m.Call([]reflect.Value{})[0] + } else { + fv, err = v.Elem().FieldByIndexErr(f.Index) + if err != nil { + if unwrap(s.options.ErrorOnMissingField) { + return nil, err + } + + continue + } + } + + av, err := s.enc.encode(fv, f.Tag) + if err != nil && unwrap(s.options.ErrorOnMissingField) { + return nil, err + } + + out[f.Name] = av + } + + return out, nil +} + +func (s *Schema[T]) Decode(m map[string]types.AttributeValue) (*T, error) { + t := new(T) + v := reflect.ValueOf(t) + + for _, f := range s.cachedFields.All() { + av, ok := m[f.Name] + if !ok { + continue + } + + if f.Tag.Setter != "" && f.Tag.Getter != "" { + current := v.MethodByName(f.Tag.Getter). + Call([]reflect.Value{})[0] + + if current.Kind() != reflect.Ptr { + current = reflect.New(current.Type()) + } + + if err := s.dec.decode(av, current, f.Tag); err != nil { + return nil, err + } + + v.MethodByName(f.Tag.Setter). + Call([]reflect.Value{ + current.Elem(), + }) + + continue + } + + fv, err := v.Elem().FieldByIndexErr(f.Index) + if err != nil { + if unwrap(s.options.ErrorOnMissingField) { + return nil, err + } + + continue + } + + err = s.dec.decode(av, fv, f.Tag) + if err != nil { + return nil, err + } + } + + return t, nil +} diff --git a/feature/dynamodb/entitymanager/table_test.go b/feature/dynamodb/entitymanager/table_test.go new file mode 100644 index 00000000000..be19d5ca560 --- /dev/null +++ b/feature/dynamodb/entitymanager/table_test.go @@ -0,0 +1,421 @@ +package entitymanager + +import ( + "context" + "fmt" + "log" + "net/http" + "strings" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression" + + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type MyAuditExtension struct{} + +func (a *MyAuditExtension) BeforeWrite(ctx context.Context, v *order) error { + log.Printf("Audit: about to write item: %+v", v.OrderID) + return nil +} + +func (a *MyAuditExtension) AfterRead(ctx context.Context, v *order) error { + log.Printf("Audit: read item: %+v", v.OrderID) + return nil +} + +func TestTableE2E(t *testing.T) { + t.Parallel() + + // Constants for test configuration + const ( + itemsToManage = 128 + tagCount = 30 + batchCount = 128 + ) + if testing.Short() { + t.Skip("skipping test in short mode.") + } + + tableName := fmt.Sprintf("test_e2e_%s", time.Now().Format("2006_01_02_15_04_05.000000000")) + + sch, err := NewSchema[order]() + if err != nil { + t.Fatalf("NewSchema() error: %v", err) + } + + sch.WithTableName(pointer(tableName)) + + { + var tags []types.Tag + for i := 0; i < tagCount; i++ { + tags = append(tags, types.Tag{ + Key: pointer(fmt.Sprintf("key%d", i)), + Value: pointer(fmt.Sprintf("value%d", i)), + }) + } + sch.WithTags(tags) + } + + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + t.Fatalf("Error loading config: %v", err) + } + c := dynamodb.NewFromConfig(cfg) + + ext := &MyAuditExtension{} + registry := DefaultExtensionRegistry[order]().Clone() + registry.AddBeforeWriter(ext) + registry.AddAfterReader(ext) + + tbl, err := NewTable[order]( + c, + WithSchema(sch), + WithExtensionRegistry(registry), + ) + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + // create + t.Logf("Table %s will be created", tableName) + err = tbl.CreateWithWait(context.Background(), time.Minute*5) + if err != nil { + t.Fatalf("CreateWithWait() error: %v", err) + } + t.Logf("Table %s ready", tableName) + + // exists + t.Logf("Table %s will be checked if it exists", tableName) + exists, err := tbl.Exists(context.Background()) + if err != nil { + t.Fatalf("Exists() error: %v", err) + } + if exists != true { + t.Fatal("Expected table to exist") + } + t.Logf("Table %s exists", tableName) + + // defer table delete + t.Cleanup(func() { + t.Logf("Table %s will be deleted", tableName) + if err := tbl.DeleteWithWait(context.Background(), time.Minute); err != nil { + t.Errorf("DeleteWithWait() error: %v", err) + } else { + t.Logf("Table %s deleted", tableName) + } + }) + + orderIds := make([]string, itemsToManage) + createdAts := make([]int64, itemsToManage) + + // Helper for order creation + createOrder := func(i int) *order { + return &order{ + CustomerID: fmt.Sprintf("CustomerID%d", i), + TotalAmount: float64(i), + IgnoredField: fmt.Sprintf("IgnoredField%d", i), + Version: 0, + VersionString: "0", + CounterUp: 0, + CounterDown: 0, + Metadata: map[string]string{ + "test": "test", + }, + address: address{ + Street: fmt.Sprintf("Street%d", i), + City: fmt.Sprintf("City%d", i), + Zip: fmt.Sprintf("Zip%d", i), + }, + Notes: []string{fmt.Sprintf("Notes%d", i)}, + customerNote: fmt.Sprintf("customerNote%d", i), + CustomerFirstName: fmt.Sprintf("CustomerFirstName%d", i), + CustomerLastName: fmt.Sprintf("CustomerLastName%d", i), + } + } + + // Put() + for i := 0; i < itemsToManage; i++ { + o := createOrder(i) + item, err := tbl.PutItem(context.Background(), o) + if err != nil { + t.Errorf("Unable to PutItem() [%d]: %v", i, err) + continue + } + + orderIds[i] = item.OrderID + createdAts[i] = item.CreatedAt + + t.Logf("PutItem: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + } + + // Get() + Update() + for i := 0; i < itemsToManage; i++ { + m := Map{}. + With("order_id", orderIds[i]). + With("created_at", createdAts[i]) + item, err := tbl.GetItem(context.Background(), m) + if err != nil { + t.Errorf("Unable to GetItem() [%s]: %v", m, err) + continue + } + t.Logf("GetItem: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + + item.TotalAmount *= 2 + + item, err = tbl.UpdateItem(context.Background(), item) + if err != nil { + t.Errorf("Unable to UpdateItem() [%s]: %v", m, err) + continue + } + t.Logf("UpdateItem: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + } + + { + t.Log("Scan()") + scanExpr := expression.Expression{} + items := tbl.Scan(context.Background(), scanExpr) + scannedItems := 0 + for res := range items { + if res.Error() != nil { + t.Errorf("Error during Scan(): %v", res.Error()) + } + item := res.Item() + t.Logf("Scan: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + + scannedItems++ + } + if scannedItems != itemsToManage { + t.Errorf("Scanned %d item(s), expected %d", scannedItems, itemsToManage) + } + } + + { + t.Log("ScanIndex()") + scanExpr := expression.Expression{} + items := tbl.ScanIndex(context.Background(), "CustomerIndex", scanExpr) + scannedItems := 0 + for res := range items { + if res.Error() != nil { + t.Errorf("Error during ScanIndex(): %v", res.Error()) + + continue + } + + item := res.Item() + if item != nil { + t.Logf("Scan: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + } else { + t.Log("no error and item was nil :(") + } + + scannedItems++ + } + if scannedItems != itemsToManage { + t.Errorf("Scanned %d item(s), expected %d", scannedItems, itemsToManage) + } + } + + knowVersions := map[string]int64{} + { + t.Log("Query()") + queriedItems := 0 + for i := range itemsToManage { + queryExprBuilder := expression.NewBuilder() + queryExprBuilder = queryExprBuilder.WithKeyCondition( + expression.Key("order_id").Equal(expression.Value(orderIds[i])).And( + expression.Key("created_at").Equal(expression.Value(createdAts[i])), + ), + ) + queryExpr, err := queryExprBuilder.Build() + if err != nil { + t.Errorf("Unable to build query: %v", err) + + return + } + + items := tbl.Query(context.Background(), queryExpr) + for res := range items { + if res.Error() != nil { + t.Errorf("Error during Query(): %v", res.Error()) + + continue + } + + item := res.Item() + t.Logf("Query: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + + knowVersions[item.OrderID] = item.Version + + queriedItems++ + } + } + if queriedItems != itemsToManage { + t.Errorf("Queried %d item(s), expected %d", queriedItems, itemsToManage) + } + } + + { + t.Log("QueryIndex()") + queriedItems := 0 + for orderID, version := range knowVersions { + queryExprBuilder := expression.NewBuilder() + queryExprBuilder = queryExprBuilder.WithKeyCondition( + expression.Key("order_id").Equal(expression.Value(orderID)).And( + expression.Key("version").Equal(expression.Value(version)), + ), + ) + queryExpr, err := queryExprBuilder.Build() + if err != nil { + t.Errorf("Unable to build query: %v", err) + + return + } + + items := tbl.QueryIndex(context.Background(), "OrderVersionIndex", queryExpr) + for res := range items { + if res.Error() != nil { + t.Errorf("Error during QueryIndex(): %v", res.Error()) + + continue + } + + item := res.Item() + t.Logf("Query: %s - %d", item.OrderID, item.CreatedAt) + t.Logf("\tVersion: %d - %s", item.Version, item.VersionString) + t.Logf("\tCounter (Up/Down): %d/%d", item.CounterUp, item.CounterDown) + + queriedItems++ + } + } + + if queriedItems != itemsToManage { + t.Errorf("Queried %d item(s), expected %d", queriedItems, itemsToManage) + } + } + + // batch + { + bwo := tbl.CreateBatchWriteOperation() + batchItems := make([]order, batchCount) + for i := 0; i < batchCount; i++ { + batchItems[i] = *createOrder(i) + if err := bwo.AddPut(&batchItems[i]); err != nil { + t.Error(err.Error()) + } + } + + if err := bwo.Execute(context.Background()); err != nil { + t.Error(err.Error()) + } else { + t.Log("BatchWritePut done") + } + for _, batchItem := range batchItems { + t.Logf("OrderID: %s", batchItem.OrderID) + } + + // get + bgo := tbl.CreateBatchGetOperation() + for i := range batchItems { + if err := bgo.AddReadItem(&batchItems[i]); err != nil { + t.Error(err.Error()) + } + } + + for item := range bgo.Execute(context.Background()) { + if item.Error() != nil { + t.Errorf("error during BatchGetOperation iteration: %v", item.Error()) + continue + } + + if item.Item() == nil { + t.Error("nil item returned") + continue + } + + found := false + for i := range batchItems { + if batchItems[i].OrderID == item.Item().OrderID { + found = true + break // optimization: break on first match + } + } + if !found { + t.Errorf("item not in initial query returned: %s", item.Item().OrderID) + } + } + + // delete + bwod := tbl.CreateBatchWriteOperation() + for i := range batchItems { + if err := bwod.AddDelete(&batchItems[i]); err != nil { + t.Error(err.Error()) + } + } + + if err := bwod.Execute(context.Background()); err != nil { + t.Error(err.Error()) + } + } +} + +type captureHTTPClient struct { + req *http.Request +} + +func (c *captureHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.req = req + return nil, nil +} + +func TestTableAddsHeaderToClient(t *testing.T) { + ctx := context.Background() + httpClient := &captureHTTPClient{} + + cfg, err := config.LoadDefaultConfig(ctx, config.WithHTTPClient(httpClient)) + if err != nil { + t.Fatalf("LoadDefaultConfig() error: %v", err) + } + + client := dynamodb.NewFromConfig(cfg) + + tbl, err := NewTable[order](client) + if err != nil { + t.Fatalf("NewTable() error: %v", err) + } + + _, _ = tbl.GetItem(ctx, Map{}) + + if httpClient.req == nil { + t.Fatal("expected HTTP request to be captured, got nil") + } + + ua := httpClient.req.Header.Get("User-Agent") + t.Logf(`Found user agent: "%s"`, ua) + if ua == "" { + t.Fatal("expected User-Agent header to be set, got empty string") + } + + if !strings.Contains(ua, UserAgentPart) { + t.Fatalf("expected User-Agent header to contain %q, got %q", UserAgentPart, ua) + } + if !strings.Contains(ua, EntityManagerVersion) { + t.Fatalf("expected User-Agent header to contain %q, got %q", EntityManagerVersion, ua) + } +} diff --git a/feature/dynamodb/entitymanager/tag.go b/feature/dynamodb/entitymanager/tag.go new file mode 100644 index 00000000000..afdff39a06d --- /dev/null +++ b/feature/dynamodb/entitymanager/tag.go @@ -0,0 +1,192 @@ +package entitymanager + +import ( + "reflect" + "strings" +) + +const ( + defaultTagKey = "dynamodbav" + tagIndex = "dynamodbindex" + tagGetter = "dynamodbgetter" + tagSetter = "dynamodbsetter" +) + +// Tag holds parsed metadata from struct field tags for DynamoDB attribute mapping. +// It captures options such as name overrides, omitempty, index info, custom converters, and more. +// Used internally for encoding/decoding and schema management. +type Tag struct { + Name string // `dynamodbav` + Ignore bool // `dynamodbav:"-" + OmitEmpty bool // `dynamodbav:",omitempty"` + OmitEmptyElem bool // `dynamodbav:",omitemptyelem"` + NullEmpty bool // `dynamodbav:",nullempty"` + NullEmptyElem bool // `dynamodbav:",nullemptyelem"` + AsString bool // `dynamodbav:",string"` + AsBinSet bool // `dynamodbav:",binaryset"` + AsNumSet bool // `dynamodbav:",numberset"` + AsStrSet bool // `dynamodbav:",stringset"` + AsUnixTime bool // `dynamodbav:",unixtime"` + Version bool // `dynamodbav:",version"` + PreserveEmpty bool // `dynamodbav:",preserveempty"` + JSON bool // `dynamodbav:",json"` + AutoGenerated bool // `dynamodbav:",autogenerated"` + AtomicCounter bool // `dynamodbav:",atomiccounter"` OR `dynamodbav:",atomiccounter|10|5"` (with startValue and delta) + EnumAsString bool // `dynamodbav:",enumasstring"` + Partition bool // `dynamodbav:",partition"` + Sort bool // `dynamodbav:",sort"` + Converter bool // `dynamodbav:",converter|int64"` + Getter string // @DynamoDBGetter = "dynamodbgetter" + Setter string // @DynamoDBSetter = "dynamodbsetter" + Indexes []Index // @DynamoDBIndex = "dynamodbindex" + Options map[string][]string // keys can be written as : +} + +// Option returns the parsed options for a given key from the struct tag, if present. +// It is used to retrieve custom tag options such as converter types or atomic counter values. +func (t *Tag) Option(k string) ([]string, bool) { + if t.Options == nil { + return nil, false + } + + v, ok := t.Options[k] + + return v, ok +} + +func (t *Tag) parseAVTag(structTag reflect.StructTag) { + tagStr := structTag.Get(defaultTagKey) + if len(tagStr) == 0 { + return + } + + t.parseTagStr(tagStr) + + t.parseIndexTag(structTag) + + t.parseGetterAndSetter(structTag) +} + +func (t *Tag) parseStructTag(tag string, structTag reflect.StructTag) { + tagStr := structTag.Get(tag) + if len(tagStr) == 0 { + return + } + + t.parseTagStr(tagStr) +} + +func (t *Tag) parseTagStr(tagStr string) { + parts := strings.Split(tagStr, ",") + if len(parts) == 0 { + return + } + + if name := parts[0]; name == "-" { + t.Name = "" + t.Ignore = true + } else { + t.Name = name + t.Ignore = false + } + + for _, opt := range parts[1:] { + if strings.Contains(opt, "|") { + if t.Options == nil { + t.Options = map[string][]string{} + } + + subOpts := strings.Split(opt, "|") + opt = subOpts[0] + subOpts = subOpts[1:] + if _, ok := t.Options[opt]; ok { + panic("tag already present with options") + } else { + t.Options[opt] = subOpts + } + } + + switch opt { + case "omitempty": + t.OmitEmpty = true + case "omitemptyelem": + t.OmitEmptyElem = true + case "nullempty": + t.NullEmpty = true + case "nullemptyelem": + t.NullEmptyElem = true + case "string": + t.AsString = true + case "binaryset": + t.AsBinSet = true + case "numberset": + t.AsNumSet = true + case "stringset": + t.AsStrSet = true + case "unixtime": + t.AsUnixTime = true + case "version": + t.Version = true + case "preserveempty": + t.PreserveEmpty = true + case "json": + t.JSON = true + case "enumasstring": + t.EnumAsString = true + case "partition": + t.Partition = true + case "sort": + t.Sort = true + case "autogenerated": + t.AutoGenerated = true + case "atomiccounter": + t.AtomicCounter = true + case "converter": + t.Converter = true + default: + continue + } + } +} + +func (t *Tag) parseIndexTag(structTag reflect.StructTag) { + idxString := structTag.Get(tagIndex) + if len(idxString) == 0 { + return + } + + indexes := strings.Split(idxString, ";") + + if len(indexes) == 0 { + return + } + + t.Indexes = make([]Index, len(indexes)) + + for c, index := range indexes { + parts := strings.Split(index, ",") + if len(parts) == 0 { + continue + } + + t.Indexes[c].Name = parts[0] + + for _, part := range parts[1:] { + switch part { + case "global": + t.Indexes[c].Global = true + case "local": + t.Indexes[c].Local = true + case "partition": + t.Indexes[c].Partition = true + case "sort": + t.Indexes[c].Sort = true + } + } + } +} + +func (t *Tag) parseGetterAndSetter(structTag reflect.StructTag) { + t.Getter = structTag.Get(tagGetter) + t.Setter = structTag.Get(tagSetter) +} diff --git a/feature/dynamodb/entitymanager/tag_test.go b/feature/dynamodb/entitymanager/tag_test.go new file mode 100644 index 00000000000..122bec7842d --- /dev/null +++ b/feature/dynamodb/entitymanager/tag_test.go @@ -0,0 +1,80 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "testing" +) + +func TestTagParse(t *testing.T) { + cases := []struct { + in reflect.StructTag + json, av bool + expect Tag + }{ + {`json:""`, true, false, Tag{}}, + {`json:"name"`, true, false, Tag{Name: "name"}}, + {`json:"name,omitempty"`, true, false, Tag{Name: "name", OmitEmpty: true}}, + {`json:"-"`, true, false, Tag{Ignore: true}}, + {`json:",omitempty"`, true, false, Tag{OmitEmpty: true}}, + {`json:",string"`, true, false, Tag{AsString: true}}, + {`dynamodbav:""`, false, true, Tag{}}, + {`dynamodbav:","`, false, true, Tag{}}, + {`dynamodbav:"name"`, false, true, Tag{Name: "name"}}, + {`dynamodbav:"name"`, false, true, Tag{Name: "name"}}, + {`dynamodbav:"-"`, false, true, Tag{Ignore: true}}, + {`dynamodbav:",omitempty"`, false, true, Tag{OmitEmpty: true}}, + {`dynamodbav:",omitemptyelem"`, false, true, Tag{OmitEmptyElem: true}}, + {`dynamodbav:",string"`, false, true, Tag{AsString: true}}, + {`dynamodbav:",binaryset"`, false, true, Tag{AsBinSet: true}}, + {`dynamodbav:",numberset"`, false, true, Tag{AsNumSet: true}}, + {`dynamodbav:",stringset"`, false, true, Tag{AsStrSet: true}}, + {`dynamodbav:",stringset,omitemptyelem"`, false, true, Tag{AsStrSet: true, OmitEmptyElem: true}}, + {`dynamodbav:"name,stringset,omitemptyelem"`, false, true, Tag{Name: "name", AsStrSet: true, OmitEmptyElem: true}}, + {`dynamodbav:",version"`, false, true, Tag{Version: true}}, + {`dynamodbav:",preserveempty"`, false, true, Tag{PreserveEmpty: true}}, + {`dynamodbav:",json"`, false, true, Tag{JSON: true}}, + {`dynamodbav:",autogenerated|key"`, false, true, Tag{AutoGenerated: true, Options: map[string][]string{"autogenerated": {"key"}}}}, + {`dynamodbav:",autogenerated|timestamp"`, false, true, Tag{AutoGenerated: true, Options: map[string][]string{"autogenerated": {"timestamp"}}}}, + {`dynamodbav:",atomiccounter"`, false, true, Tag{AtomicCounter: true}}, + {`dynamodbav:",atomiccounter|start=0"`, false, true, Tag{AtomicCounter: true, Options: map[string][]string{"atomiccounter": {"start=0"}}}}, + {`dynamodbav:",atomiccounter|delta=1"`, false, true, Tag{AtomicCounter: true, Options: map[string][]string{"atomiccounter": {"delta=1"}}}}, + {`dynamodbav:",atomiccounter|delta=-1"`, false, true, Tag{AtomicCounter: true, Options: map[string][]string{"atomiccounter": {"delta=-1"}}}}, + {`dynamodbav:",atomiccounter|start=0|delta=-1"`, false, true, Tag{AtomicCounter: true, Options: map[string][]string{"atomiccounter": {"start=0", "delta=-1"}}}}, + {`dynamodbav:",enumasstring"`, false, true, Tag{EnumAsString: true}}, + {`dynamodbav:",partition"`, false, true, Tag{Partition: true}}, + {`dynamodbav:",sort"`, false, true, Tag{Sort: true}}, + {`dynamodbgetter:"Prop" dynamodbsetter:"SetProp"`, false, true, Tag{Getter: "", Setter: ""}}, + {`dynamodbav:"prop" dynamodbgetter:"Prop" dynamodbsetter:"SetProp"`, false, true, Tag{Name: "prop", Getter: "Prop", Setter: "SetProp"}}, + {`dynamodbav:"prop" dynamodbindex:"idx"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx"}}}}, + {`dynamodbav:"prop" dynamodbindex:"idx,local"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx", Local: true}}}}, + {`dynamodbav:"prop" dynamodbindex:"idx,global"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx", Global: true}}}}, + {`dynamodbav:"prop" dynamodbindex:"idx,partition"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx", Partition: true}}}}, + {`dynamodbav:"prop" dynamodbindex:"idx,sort"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx", Sort: true}}}}, + {`dynamodbav:"prop" dynamodbindex:"idx,global,local,partition,sort"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx", Global: true, Local: true, Partition: true, Sort: true}}}}, + {`dynamodbav:"prop" dynamodbindex:",global,local,partition,sort"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "", Global: true, Local: true, Partition: true, Sort: true}}}}, + {`dynamodbav:"prop" dynamodbindex:","`, false, true, Tag{Name: "prop", Indexes: []Index{{}}}}, + {`dynamodbav:"prop" dynamodbindex:""`, false, true, Tag{Name: "prop"}}, + {`dynamodbav:"prop,converter|float64"`, false, true, Tag{Name: "prop", Converter: true, Options: map[string][]string{"converter": {"float64"}}}}, + {`dynamodbav:"prop,converter|time|format=2006-01-02"`, false, true, Tag{Name: "prop", Converter: true, Options: map[string][]string{"converter": {"time", "format=2006-01-02"}}}}, + // unsupported tags are ignored + {`dynamodbav:"prop" dynamodbindex:"idx,unsupportedtag"`, false, true, Tag{Name: "prop", Indexes: []Index{{Name: "idx"}}}}, + {`dynamodbav:",unsupportedtag"`, false, true, Tag{}}, + //{`dynamodbav:",flatten"`, false, true, Tag{Flatten: true}}, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + actual := Tag{} + if c.json { + actual.parseStructTag("json", c.in) + } + if c.av { + actual.parseAVTag(c.in) + } + if e, a := c.expect, actual; !reflect.DeepEqual(e, a) { + t.Errorf("case %d [%s], expect %v, got %v", i, c.in, e, a) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/util.go b/feature/dynamodb/entitymanager/util.go new file mode 100644 index 00000000000..1da80bd0f85 --- /dev/null +++ b/feature/dynamodb/entitymanager/util.go @@ -0,0 +1,41 @@ +package entitymanager + +import ( + "reflect" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func pointer[T any](v T) *T { + return &v +} + +func unwrap[T any](v *T) T { + if v != nil { + return *v + } + + return *new(T) +} + +func typeToScalarAttributeType(t reflect.Type) (types.ScalarAttributeType, bool) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + switch t.Kind() { + case reflect.String: + return types.ScalarAttributeTypeS, true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return types.ScalarAttributeTypeN, true + case reflect.Slice, reflect.Array: + if t.Elem().Kind() == reflect.Uint8 { + return types.ScalarAttributeTypeB, true + } + fallthrough + default: + return "", false // unknown or unsupported kind + } +} diff --git a/feature/dynamodb/entitymanager/util_test.go b/feature/dynamodb/entitymanager/util_test.go new file mode 100644 index 00000000000..2a19a72cc0d --- /dev/null +++ b/feature/dynamodb/entitymanager/util_test.go @@ -0,0 +1,338 @@ +package entitymanager + +import ( + "fmt" + "reflect" + "testing" + + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +func TestTypeToScalarAttributeType(t *testing.T) { + cases := []struct { + input reflect.Type + expected types.ScalarAttributeType + ok bool + }{ + { + input: reflect.TypeFor[uint](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[uint8](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[uint16](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[uint32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[uint64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[int](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[int8](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[int16](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[int32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[int64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[float32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[float64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[complex64](), + ok: false, + }, + { + input: reflect.TypeFor[complex128](), + ok: false, + }, + { + input: reflect.TypeFor[string](), + expected: types.ScalarAttributeTypeS, + ok: true, + }, + { + input: reflect.TypeFor[[]byte](), + expected: types.ScalarAttributeTypeB, + ok: true, + }, + { + input: reflect.TypeFor[[1]byte](), + expected: types.ScalarAttributeTypeB, + ok: true, + }, + { + input: reflect.TypeFor[*uint](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*uint8](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*uint16](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*uint32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*uint64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*int](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*int8](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*int16](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*int32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*int64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*float32](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*float64](), + expected: types.ScalarAttributeTypeN, + ok: true, + }, + { + input: reflect.TypeFor[*complex64](), + ok: false, + }, + { + input: reflect.TypeFor[*complex128](), + ok: false, + }, + { + input: reflect.TypeFor[*string](), + expected: types.ScalarAttributeTypeS, + ok: true, + }, + { + input: reflect.TypeFor[*[]byte](), + expected: types.ScalarAttributeTypeB, + ok: true, + }, + { + input: reflect.TypeFor[*[1]byte](), + expected: types.ScalarAttributeTypeB, + ok: true, + }, + { + input: reflect.TypeFor[order](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[*order](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[[]order](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[*[]order](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[map[string]string](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[*map[string]string](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[[]map[string]string](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[*[]map[string]string](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[any](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[[]any](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[map[string]string](), + expected: "", + ok: false, + }, + { + input: reflect.TypeFor[chan any](), + expected: "", + ok: false, + }, + } + + for i, c := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + actual, ok := typeToScalarAttributeType(c.input) + + if diff := cmpDiff(c.expected, actual); len(diff) != 0 { + t.Errorf("different values: %s", diff) + } + + if diff := cmpDiff(c.ok, ok); len(diff) != 0 { + t.Errorf("different values: %s", diff) + } + }) + } +} +func TestPointer(t *testing.T) { + type foo struct{ X int } + cases := []struct { + name string + input any + want any + }{ + {"int", 42, 42}, + {"string", "hello", "hello"}, + {"struct", foo{X: 7}, foo{X: 7}}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + switch v := c.input.(type) { + case int: + p := pointer(v) + if p == nil || *p != c.want.(int) { + t.Errorf("pointer(int): got %v, want pointer to %v", p, c.want) + } + case string: + p := pointer(v) + if p == nil || *p != c.want.(string) { + t.Errorf("pointer(string): got %v, want pointer to %v", p, c.want) + } + case foo: + p := pointer(v) + if p == nil || *p != c.want.(foo) { + t.Errorf("pointer(struct): got %v, want pointer to %+v", p, c.want) + } + default: + t.Fatalf("unsupported type: %T", v) + } + }) + } +} + +func TestUnwrap(t *testing.T) { + type foo struct{ X int } + i := 99 + s := "world" + f := foo{X: 123} + + cases := []struct { + name string + input any + want any + }{ + {"*int", &i, i}, + {"nil *int", (*int)(nil), 0}, + {"*string", &s, s}, + {"nil *string", (*string)(nil), ""}, + {"*struct", &f, f}, + {"nil *struct", (*foo)(nil), foo{}}, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + switch v := c.input.(type) { + case *int: + got := unwrap(v) + if got != c.want.(int) { + t.Errorf("unwrap(*int): got %v, want %v", got, c.want) + } + case *string: + got := unwrap(v) + if got != c.want.(string) { + t.Errorf("unwrap(*string): got %v, want %v", got, c.want) + } + case *foo: + got := unwrap(v) + if got != c.want.(foo) { + t.Errorf("unwrap(*struct): got %+v, want %+v", got, c.want) + } + default: + t.Fatalf("unsupported type: %T", v) + } + }) + } +} diff --git a/feature/dynamodb/entitymanager/version.go b/feature/dynamodb/entitymanager/version.go new file mode 100644 index 00000000000..1815beb9b8c --- /dev/null +++ b/feature/dynamodb/entitymanager/version.go @@ -0,0 +1,4 @@ +package entitymanager + +const UserAgentPart = "entity-manager" +const EntityManagerVersion = "0.1.0"