Skip to content

Commit 862e30c

Browse files
authored
Feature/backport-3658-2 (#3877) (#4035)
1 parent 635fe87 commit 862e30c

File tree

2 files changed

+137
-45
lines changed

2 files changed

+137
-45
lines changed

state/aws/dynamodb/dynamodb.go

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ type dynamoDBMetadata struct {
6666
PartitionKey string `json:"partitionKey"`
6767
}
6868

69+
type putData struct {
70+
ConditionExpression *string
71+
ExpressionAttributeValues map[string]types.AttributeValue
72+
Item map[string]types.AttributeValue
73+
TableName *string
74+
}
75+
6976
const (
7077
defaultPartitionKeyName = "key"
7178
metadataPartitionKey = "partitionKey"
@@ -171,9 +178,9 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
171178
return &state.GetResponse{}, nil
172179
}
173180

174-
var output string
175-
if err = attributevalue.Unmarshal(result.Item["value"], &output); err != nil {
176-
return nil, err
181+
data, err := unmarshalValue(result.Item["value"])
182+
if err != nil {
183+
return nil, fmt.Errorf("dynamodb error: failed to unmarshal value for key %s: %w", req.Key, err)
177184
}
178185

179186
var metadata map[string]string
@@ -194,7 +201,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
194201
}
195202

196203
resp := &state.GetResponse{
197-
Data: []byte(output),
204+
Data: data,
198205
Metadata: metadata,
199206
}
200207

@@ -212,29 +219,12 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get
212219

213220
// Set saves a dynamoDB item.
214221
func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error {
215-
item, err := d.getItemFromReq(req)
222+
pd, err := d.createPutData(req)
216223
if err != nil {
217224
return err
218225
}
219226

220-
input := &dynamodb.PutItemInput{
221-
Item: item,
222-
TableName: &d.table,
223-
}
224-
225-
if req.HasETag() {
226-
condExpr := "etag = :etag"
227-
input.ConditionExpression = &condExpr
228-
exprAttrValues := make(map[string]types.AttributeValue)
229-
exprAttrValues[":etag"] = &types.AttributeValueMemberS{
230-
Value: *req.ETag,
231-
}
232-
input.ExpressionAttributeValues = exprAttrValues
233-
} else if req.Options.Concurrency == state.FirstWrite {
234-
condExpr := "attribute_not_exists(etag)"
235-
input.ConditionExpression = &condExpr
236-
}
237-
_, err = d.dynamodbClient.PutItem(ctx, input)
227+
_, err = d.dynamodbClient.PutItem(ctx, pd.ToPutItemInput())
238228
if err != nil && req.HasETag() {
239229
var cErr *types.ConditionalCheckFailedException
240230
switch {
@@ -298,9 +288,55 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata
298288
return &m, err
299289
}
300290

301-
// getItemFromReq converts a dapr state.SetRequest into an dynamodb item
302-
func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]types.AttributeValue, error) {
303-
value, err := d.marshalToString(req.Value)
291+
// createPutData creates a DynamoDB put request data from a SetRequest.
292+
func (d *StateStore) createPutData(req *state.SetRequest) (putData, error) {
293+
item, err := d.createItem(req)
294+
if err != nil {
295+
return putData{}, err
296+
}
297+
298+
pd := putData{
299+
Item: item,
300+
TableName: ptr.Of(d.table),
301+
}
302+
303+
if req.HasETag() {
304+
condExpr := "etag = :etag"
305+
pd.ConditionExpression = &condExpr
306+
exprAttrValues := make(map[string]types.AttributeValue)
307+
exprAttrValues[":etag"] = &types.AttributeValueMemberS{
308+
Value: *req.ETag,
309+
}
310+
pd.ExpressionAttributeValues = exprAttrValues
311+
} else if req.Options.Concurrency == state.FirstWrite {
312+
condExpr := "attribute_not_exists(etag)"
313+
pd.ConditionExpression = &condExpr
314+
}
315+
316+
return pd, nil
317+
}
318+
319+
func (d putData) ToPutItemInput() *dynamodb.PutItemInput {
320+
return &dynamodb.PutItemInput{
321+
ConditionExpression: d.ConditionExpression,
322+
ExpressionAttributeValues: d.ExpressionAttributeValues,
323+
Item: d.Item,
324+
TableName: d.TableName,
325+
}
326+
}
327+
328+
func (d putData) ToPut() *types.Put {
329+
return &types.Put{
330+
ConditionExpression: d.ConditionExpression,
331+
ExpressionAttributeValues: d.ExpressionAttributeValues,
332+
Item: d.Item,
333+
TableName: d.TableName,
334+
}
335+
}
336+
337+
// createItem creates a DynamoDB item from a SetRequest.
338+
func (d *StateStore) createItem(req *state.SetRequest) (map[string]types.AttributeValue, error) {
339+
value, err := marshalValue(req.Value)
304340
if err != nil {
305341
return nil, fmt.Errorf("dynamodb error: failed to marshal value for key %s: %w", req.Key, err)
306342
}
@@ -319,9 +355,7 @@ func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]types.Att
319355
d.partitionKey: &types.AttributeValueMemberS{
320356
Value: req.Key,
321357
},
322-
"value": &types.AttributeValueMemberS{
323-
Value: value,
324-
},
358+
"value": value,
325359
"etag": &types.AttributeValueMemberS{
326360
Value: strconv.FormatUint(newEtag, 16),
327361
},
@@ -346,12 +380,35 @@ func getRand64() (uint64, error) {
346380
return binary.LittleEndian.Uint64(randBuf), nil
347381
}
348382

349-
func (d *StateStore) marshalToString(v interface{}) (string, error) {
350-
if buf, ok := v.([]byte); ok {
351-
return string(buf), nil
383+
func marshalValue(v interface{}) (types.AttributeValue, error) {
384+
if bt, ok := v.([]byte); ok {
385+
return &types.AttributeValueMemberB{Value: bt}, nil
386+
}
387+
388+
str, err := jsoniterator.ConfigFastest.MarshalToString(v)
389+
if err != nil {
390+
return nil, err
391+
}
392+
393+
return &types.AttributeValueMemberS{Value: str}, nil
394+
}
395+
396+
func unmarshalValue(value types.AttributeValue) ([]byte, error) {
397+
if value == nil {
398+
return []byte(nil), nil
399+
}
400+
401+
var bytes []byte
402+
if err := attributevalue.Unmarshal(value, &bytes); err == nil {
403+
return bytes, nil
404+
}
405+
406+
var str string
407+
if err := attributevalue.Unmarshal(value, &str); err == nil {
408+
return []byte(str), nil
352409
}
353410

354-
return jsoniterator.ConfigFastest.MarshalToString(v)
411+
return nil, fmt.Errorf("unsupported attribute value type %T", value)
355412
}
356413

357414
// Parse and process ttlInSeconds.
@@ -410,21 +467,11 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat
410467
twi := types.TransactWriteItem{}
411468
switch req := o.(type) {
412469
case state.SetRequest:
413-
value, err := d.marshalToString(req.Value)
470+
pd, err := d.createPutData(&req)
414471
if err != nil {
415472
return fmt.Errorf("dynamodb error: failed to marshal value for key %s: %w", req.Key, err)
416473
}
417-
twi.Put = &types.Put{
418-
TableName: ptr.Of(d.table),
419-
Item: map[string]types.AttributeValue{
420-
d.partitionKey: &types.AttributeValueMemberS{
421-
Value: req.Key,
422-
},
423-
"value": &types.AttributeValueMemberS{
424-
Value: value,
425-
},
426-
},
427-
}
474+
twi.Put = pd.ToPut()
428475

429476
case state.DeleteRequest:
430477
twi.Delete = &types.Delete{

state/aws/dynamodb/dynamodb_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,51 @@ func TestSet(t *testing.T) {
499499
require.NoError(t, err)
500500
})
501501

502+
t.Run("Successfully set item with binary value", func(t *testing.T) {
503+
mockedDB := &awsAuth.MockDynamoDB{
504+
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {
505+
assert.Equal(t, dynamodb.AttributeValue{
506+
S: aws.String("key"),
507+
}, *input.Item["key"])
508+
assert.Equal(t, dynamodb.AttributeValue{
509+
B: []byte("value"),
510+
}, *input.Item["value"])
511+
assert.Len(t, input.Item, 3)
512+
513+
return &dynamodb.PutItemOutput{
514+
Attributes: map[string]*dynamodb.AttributeValue{
515+
"key": {
516+
S: aws.String("value"),
517+
},
518+
},
519+
}, nil
520+
},
521+
}
522+
523+
dynamo := awsAuth.DynamoDBClients{
524+
DynamoDB: mockedDB,
525+
}
526+
527+
mockedClients := awsAuth.Clients{
528+
Dynamo: &dynamo,
529+
}
530+
531+
mockAuthProvider := &awsAuth.StaticAuth{}
532+
mockAuthProvider.WithMockClients(&mockedClients)
533+
s := StateStore{
534+
authProvider: mockAuthProvider,
535+
partitionKey: defaultPartitionKeyName,
536+
}
537+
538+
req := &state.SetRequest{
539+
Key: "key",
540+
Value: []byte("value"),
541+
}
542+
err := s.Set(context.Background(), req)
543+
544+
require.NoError(t, err)
545+
})
546+
502547
t.Run("Successfully set item with matching etag", func(t *testing.T) {
503548
mockedDB := &awsAuth.MockDynamoDB{
504549
PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) {

0 commit comments

Comments
 (0)