Skip to content

Commit fdd9c79

Browse files
authored
Add cached MapReduce operation (#143)
* Add cached MapReduce operation ``` === RUN TestMapReduceSimple mapreduce_test.go:69: tree size: 600000, cache size: 1000 mapreduce_test.go:76: fresh readCount: 36891 mapreduce_test.go:83: fresh re-readCount: 0 mapreduce_test.go:103: new key readCount: 38 mapreduce_test.go:114: repeat readCount: 0 mapreduce_test.go:121: repeat readCount: 0 mapreduce_test.go:141: new two keys readCount: 76 Signed-off-by: Jakub Sztandera <oss@kubuxu.com> --- PASS: TestMapReduceSimple (6.89s) ``` Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * reuse reader Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * Change signature Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * cleanup imports Signed-off-by: Jakub Sztandera <oss@kubuxu.com> * make test a bit more deterministic Signed-off-by: Jakub Sztandera <oss@kubuxu.com> --------- Signed-off-by: Jakub Sztandera <oss@kubuxu.com>
1 parent 9f2472e commit fdd9c79

File tree

3 files changed

+375
-0
lines changed

3 files changed

+375
-0
lines changed

hamt_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,35 @@ func TestSha256(t *testing.T) {
531531
}))
532532
}
533533

534+
func TestForEach(t *testing.T) {
535+
ctx := context.Background()
536+
cs := cbor.NewCborStore(newMockBlocks())
537+
begn, err := NewNode(cs)
538+
require.NoError(t, err)
539+
540+
golden := make(map[string]*CborByteArray)
541+
for range 1000 {
542+
k := randKey()
543+
v := randValue()
544+
golden[k] = v
545+
err = begn.Set(ctx, k, v)
546+
require.NoError(t, err)
547+
}
548+
err = begn.Flush(ctx)
549+
require.NoError(t, err)
550+
err = begn.ForEach(ctx, func(k string, val *cbg.Deferred) error {
551+
v, ok := golden[k]
552+
if !ok {
553+
t.Fatalf("unexpected key in ForEach: %s", k)
554+
}
555+
var val2 CborByteArray
556+
val2.UnmarshalCBOR(bytes.NewReader(val.Raw))
557+
require.Equal(t, []byte(*v), []byte(val2))
558+
return nil
559+
})
560+
require.NoError(t, err)
561+
}
562+
534563
func testBasic(t *testing.T, options ...Option) {
535564
ctx := context.Background()
536565
cs := cbor.NewCborStore(newMockBlocks())

mapreduce.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package hamt
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"fmt"
8+
"math/rand/v2"
9+
"sync"
10+
11+
"github.com/ipfs/go-cid"
12+
cbor "github.com/ipfs/go-ipld-cbor"
13+
cbg "github.com/whyrusleeping/cbor-gen"
14+
xerrors "golang.org/x/xerrors"
15+
)
16+
17+
// cacheEntry pairs a cached value with its weight: the number of store
// reads that were needed to compute it. Heavier entries are more expensive
// to recompute and therefore more valuable to keep cached.
type cacheEntry[T any] struct {
	value  T
	weight int
}
21+
22+
// weigthted2RCache is a bounded cache with weighted two-random eviction:
// when over capacity, two entries are sampled via map iteration and one is
// evicted with probability inversely proportional to its weight.
// NOTE(review): the type name misspells "weighted"; it is unexported, so a
// rename would be a safe follow-up.
type weigthted2RCache[T any] struct {
	lk        sync.Mutex // guards cache
	cache     map[cid.Cid]cacheEntry[T]
	cacheSize int // soft capacity; eviction triggers once exceeded
}
27+
28+
func newWeighted2RCache[T any](cacheSize int) *weigthted2RCache[T] {
29+
return &weigthted2RCache[T]{
30+
cache: make(map[cid.Cid]cacheEntry[T]),
31+
cacheSize: cacheSize,
32+
}
33+
}
34+
35+
func (c *weigthted2RCache[T]) Get(k cid.Cid) (cacheEntry[T], bool) {
36+
c.lk.Lock()
37+
defer c.lk.Unlock()
38+
v, ok := c.cache[k]
39+
if !ok {
40+
return v, false
41+
}
42+
return v, true
43+
}
44+
45+
// Add inserts v under k. Entries that were cheap to compute are not cached
// at all. If the insert pushes the cache over cacheSize, two entries are
// sampled at random and one is evicted, biased toward keeping the heavier
// (more expensive to recompute) entry. Safe for concurrent use.
func (c *weigthted2RCache[T]) Add(k cid.Cid, v cacheEntry[T]) {
	// dont cache nodes that require less than 6 reads
	if v.weight <= 5 {
		return
	}
	c.lk.Lock()
	defer c.lk.Unlock()
	if _, ok := c.cache[k]; ok {
		// overwriting an existing key cannot grow the cache, so no eviction
		c.cache[k] = v
		return
	}

	c.cache[k] = v
	if len(c.cache) > c.cacheSize {
		// pick two random entries using map iteration
		// this works well for cacheSize > 8
		// (Go map iteration order is randomized; the two samples may
		// coincide, in which case that entry is evicted either way)
		var k1, k2 cid.Cid
		var v1, v2 cacheEntry[T]
		for k, v := range c.cache {
			k1 = k
			v1 = v
			break
		}
		for k, v := range c.cache {
			k2 = k
			v2 = v
			break
		}
		// pick random one based on weight
		// v1 survives with probability v1.weight/(v1.weight+v2.weight)
		r1 := rand.Float64()
		if r1 < float64(v1.weight)/float64(v1.weight+v2.weight) {
			delete(c.cache, k2)
		} else {
			delete(c.cache, k1)
		}
	}
}
82+
83+
// CachedMapReduce is a map reduce implementation that caches intermediate results
// to reduce the number of reads from the underlying store.
//
// T is the value type stored in the HAMT; PT is its pointer form, which must
// implement cbg.CBORUnmarshaler; U is the mapped/reduced result type.
type CachedMapReduce[T any, PT interface {
	*T
	cbg.CBORUnmarshaler
}, U any] struct {
	mapper  func(string, T) (U, error) // maps one key/value pair to a U
	reducer func([]U) (U, error)       // folds child/leaf Us into one U
	cache   *weigthted2RCache[U]       // per-link intermediate results
}
93+
94+
// NewCachedMapReduce creates a new CachedMapReduce instance.
95+
// The mapper translates a key-value pair stored in the HAMT into a chosen U value.
96+
// The reducer reduces the U values into a single U value.
97+
// The cacheSize parameter specifies the maximum number of intermediate results to cache.
98+
func NewCachedMapReduce[T any, PT interface {
99+
*T
100+
cbg.CBORUnmarshaler
101+
}, U any](
102+
mapper func(string, T) (U, error),
103+
reducer func([]U) (U, error),
104+
cacheSize int,
105+
) (*CachedMapReduce[T, PT, U], error) {
106+
return &CachedMapReduce[T, PT, U]{
107+
mapper: mapper,
108+
reducer: reducer,
109+
cache: newWeighted2RCache[U](cacheSize),
110+
}, nil
111+
}
112+
113+
// MapReduce applies the map reduce function to the given root node.
114+
func (cmr *CachedMapReduce[T, PT, U]) MapReduce(ctx context.Context, cs cbor.IpldStore, c cid.Cid, options ...Option) (U, error) {
115+
var res U
116+
root, err := LoadNode(ctx, cs, c, options...)
117+
if err != nil {
118+
return res, xerrors.Errorf("failed to load root node: %w", err)
119+
}
120+
121+
ce, err := cmr.mapReduceInternal(ctx, root)
122+
if err != nil {
123+
return res, err
124+
}
125+
return ce.value, nil
126+
}
127+
128+
func (cmr *CachedMapReduce[T, PT, U]) mapReduceInternal(ctx context.Context, node *Node) (cacheEntry[U], error) {
129+
var res cacheEntry[U]
130+
131+
Us := make([]U, 0)
132+
weight := 1
133+
for _, p := range node.Pointers {
134+
if p.cache != nil && p.dirty {
135+
return res, errors.New("cannot iterate over a dirty node")
136+
}
137+
if p.isShard() {
138+
if p.cache != nil && p.dirty {
139+
return res, errors.New("cannot iterate over a dirty node")
140+
}
141+
linkU, ok := cmr.cache.Get(p.Link)
142+
if !ok {
143+
chnd, err := p.loadChild(ctx, node.store, node.bitWidth, node.hash)
144+
if err != nil {
145+
return res, fmt.Errorf("loading child: %w", err)
146+
}
147+
148+
linkU, err = cmr.mapReduceInternal(ctx, chnd)
149+
if err != nil {
150+
return res, fmt.Errorf("map reduce child: %w", err)
151+
}
152+
cmr.cache.Add(p.Link, linkU)
153+
}
154+
Us = append(Us, linkU.value)
155+
weight += linkU.weight
156+
} else {
157+
reader := bytes.NewReader(nil)
158+
for _, v := range p.KVs {
159+
var pt = PT(new(T))
160+
reader.Reset(v.Value.Raw)
161+
err := pt.UnmarshalCBOR(reader)
162+
if err != nil {
163+
return res, fmt.Errorf("failed to unmarshal value: %w", err)
164+
}
165+
u, err := cmr.mapper(string(v.Key), *pt)
166+
if err != nil {
167+
return res, fmt.Errorf("failed to map value: %w", err)
168+
}
169+
170+
Us = append(Us, u)
171+
}
172+
}
173+
}
174+
175+
resU, err := cmr.reducer(Us)
176+
if err != nil {
177+
return res, fmt.Errorf("failed to reduce self values: %w", err)
178+
}
179+
180+
return cacheEntry[U]{
181+
value: resU,
182+
weight: weight,
183+
}, nil
184+
}

mapreduce_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package hamt
2+
3+
import (
4+
"context"
5+
"encoding/hex"
6+
"math/rand/v2"
7+
"slices"
8+
"strings"
9+
"testing"
10+
11+
cid "github.com/ipfs/go-cid"
12+
cbor "github.com/ipfs/go-ipld-cbor"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
// readCounterStore wraps an IpldStore and counts Get calls, letting the
// test assert how many block reads an operation performed.
type readCounterStore struct {
	cbor.IpldStore
	readCount int // number of Get calls since last reset
}
20+
21+
// Get increments the read counter and delegates to the wrapped store.
func (rcs *readCounterStore) Get(ctx context.Context, c cid.Cid, out any) error {
	rcs.readCount++
	return rcs.IpldStore.Get(ctx, c, out)
}
25+
26+
// deterministicKVGen yields a reproducible stream of key/value pairs from a
// seeded ChaCha8 generator, keeping the test deterministic across runs.
type deterministicKVGen struct {
	rng *rand.ChaCha8
}
29+
30+
func (dkg *deterministicKVGen) GenKV() (string, *CborByteArray) {
31+
key := make([]byte, 18)
32+
dkg.rng.Read(key)
33+
val := CborByteArray(make([]byte, 30))
34+
dkg.rng.Read(val)
35+
return hex.EncodeToString(key), &val
36+
}
37+
38+
func TestMapReduceSimple(t *testing.T) {
39+
ctx := context.Background()
40+
opts := []Option{UseTreeBitWidth(5)}
41+
cs := &readCounterStore{cbor.NewCborStore(newMockBlocks()), 0}
42+
43+
gen := deterministicKVGen{rng: rand.NewChaCha8([32]byte{})}
44+
45+
N := 50000
46+
var rootCid cid.Cid
47+
golden := make(map[string]string)
48+
{
49+
begn, err := NewNode(cs, opts...)
50+
require.NoError(t, err)
51+
52+
for range N {
53+
k, v := gen.GenKV()
54+
golden[k] = string([]byte(*v))
55+
begn.Set(ctx, k, v)
56+
}
57+
58+
rootCid, err = begn.Write(ctx)
59+
require.NoError(t, err)
60+
}
61+
62+
type kv struct {
63+
k string
64+
v string
65+
}
66+
67+
mapper := func(k string, v CborByteArray) ([]kv, error) {
68+
return []kv{{k, string([]byte(v))}}, nil
69+
}
70+
reducer := func(kvs [][]kv) ([]kv, error) {
71+
var kvsConcat []kv
72+
for _, kvs := range kvs {
73+
kvsConcat = append(kvsConcat, kvs...)
74+
}
75+
slices.SortFunc(kvsConcat, func(a, b kv) int {
76+
return strings.Compare(a.k, b.k)
77+
})
78+
return kvsConcat, nil
79+
}
80+
81+
cmr, err := NewCachedMapReduce(mapper, reducer, int(N/300))
82+
t.Logf("tree size: %d, cache size: %d", N, cmr.cache.cacheSize)
83+
require.NoError(t, err)
84+
85+
cs.readCount = 0
86+
res, err := cmr.MapReduce(ctx, cs, rootCid, opts...)
87+
require.NoError(t, err)
88+
require.Equal(t, len(golden), len(res))
89+
t.Logf("fresh readCount: %d", cs.readCount)
90+
91+
cs.readCount = 0
92+
res, err = cmr.MapReduce(ctx, cs, rootCid, opts...)
93+
require.NoError(t, err)
94+
t.Logf("fresh re-readCount: %d", cs.readCount)
95+
require.Less(t, cs.readCount, 200)
96+
97+
verifyConsistency := func(res []kv) {
98+
t.Helper()
99+
mappedRes := make(map[string]string)
100+
for _, kv := range res {
101+
mappedRes[kv.k] = kv.v
102+
}
103+
require.Equal(t, len(golden), len(mappedRes))
104+
require.Equal(t, golden, mappedRes)
105+
}
106+
verifyConsistency(res)
107+
108+
{
109+
begn, err := LoadNode(ctx, cs, rootCid, opts...)
110+
require.NoError(t, err)
111+
// add new key
112+
k, v := gen.GenKV()
113+
golden[k] = string([]byte(*v))
114+
begn.Set(ctx, k, v)
115+
116+
rootCid, err = begn.Write(ctx)
117+
require.NoError(t, err)
118+
}
119+
120+
cs.readCount = 0
121+
res, err = cmr.MapReduce(ctx, cs, rootCid, opts...)
122+
require.NoError(t, err)
123+
verifyConsistency(res)
124+
t.Logf("new key readCount: %d", cs.readCount)
125+
require.Less(t, cs.readCount, 200)
126+
127+
cs.readCount = 0
128+
res, err = cmr.MapReduce(ctx, cs, rootCid, opts...)
129+
require.NoError(t, err)
130+
verifyConsistency(res)
131+
t.Logf("repeat readCount: %d", cs.readCount)
132+
require.Less(t, cs.readCount, 200)
133+
134+
cs.readCount = 0
135+
res, err = cmr.MapReduce(ctx, cs, rootCid, opts...)
136+
require.NoError(t, err)
137+
verifyConsistency(res)
138+
t.Logf("repeat readCount: %d", cs.readCount)
139+
require.Less(t, cs.readCount, 200)
140+
141+
{
142+
begn, err := LoadNode(ctx, cs, rootCid, opts...)
143+
require.NoError(t, err)
144+
// add two new keys
145+
k, v := gen.GenKV()
146+
golden[k] = string([]byte(*v))
147+
begn.Set(ctx, k, v)
148+
k, v = gen.GenKV()
149+
golden[k] = string([]byte(*v))
150+
begn.Set(ctx, k, v)
151+
152+
rootCid, err = begn.Write(ctx)
153+
require.NoError(t, err)
154+
}
155+
156+
cs.readCount = 0
157+
res, err = cmr.MapReduce(ctx, cs, rootCid, opts...)
158+
require.NoError(t, err)
159+
verifyConsistency(res)
160+
t.Logf("new two keys readCount: %d", cs.readCount)
161+
require.Less(t, cs.readCount, 300)
162+
}

0 commit comments

Comments
 (0)