Skip to content

Commit 8e958e9

Browse files
json: refactor ref-cycle handling
Also set UnsupportedValueError.Value (better stdlib compat).
1 parent e960845 commit 8e958e9

File tree

2 files changed

+79
-16
lines changed

2 files changed

+79
-16
lines changed

json/codec.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sort"
1111
"strconv"
1212
"strings"
13+
"sync"
1314
"sync/atomic"
1415
"time"
1516
"unicode"
@@ -37,8 +38,19 @@ type encoder struct {
3738
// encoder starts tracking pointers it has seen as an attempt to detect
3839
// whether it has entered a pointer cycle and needs to error before the
3940
// goroutine runs out of stack space.
41+
//
42+
// This relies on encoder being passed as a value,
43+
// and encoder methods calling each other in a traditional stack
44+
// (not using trampoline techniques),
45+
// since ptrDepth is never decremented.
4046
ptrDepth uint32
41-
ptrSeen map[unsafe.Pointer]struct{}
47+
ptrSeen cycleMap
48+
}
49+
50+
type cycleMap map[unsafe.Pointer]struct{}
51+
52+
var cycleMapPool = sync.Pool{
53+
New: func() any { return make(cycleMap) },
4254
}
4355

4456
type decoder struct {

json/encode.go

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -794,22 +794,21 @@ func (e encoder) encodeEmbeddedStructPointer(b []byte, p unsafe.Pointer, t refle
794794
}
795795

796796
func (e encoder) encodePointer(b []byte, p unsafe.Pointer, t reflect.Type, encode encodeFunc) ([]byte, error) {
797-
if p = *(*unsafe.Pointer)(p); p != nil {
798-
if e.ptrDepth++; e.ptrDepth >= startDetectingCyclesAfter {
799-
if _, seen := e.ptrSeen[p]; seen {
800-
// TODO: reconstruct the reflect.Value from p + t so we can set
801-
// the erorr's Value field?
802-
return b, &UnsupportedValueError{Str: fmt.Sprintf("encountered a cycle via %s", t)}
803-
}
804-
if e.ptrSeen == nil {
805-
e.ptrSeen = make(map[unsafe.Pointer]struct{})
806-
}
807-
e.ptrSeen[p] = struct{}{}
808-
defer delete(e.ptrSeen, p)
809-
}
810-
return encode(e, b, p)
797+
// p was a pointer to the actual user data pointer:
798+
// dereference it to operate on the user data pointer.
799+
p = *(*unsafe.Pointer)(p)
800+
if p == nil {
801+
return e.encodeNull(b, nil)
802+
}
803+
804+
err := checkRefCycle(&e, t, p)
805+
if err != nil {
806+
return b, err
811807
}
812-
return e.encodeNull(b, nil)
808+
809+
defer freeRefCycleInfo(&e, p)
810+
811+
return encode(e, b, p)
813812
}
814813

815814
func (e encoder) encodeInterface(b []byte, p unsafe.Pointer) ([]byte, error) {
@@ -968,3 +967,55 @@ func appendCompactEscapeHTML(dst []byte, src []byte) []byte {
968967

969968
return dst
970969
}
970+
971+
// checkRefCycle returns an error if a reference cycle was detected.
972+
// The data pointer passed in should be equivalent to one of:
973+
//
974+
// - A normal Go pointer, e.g. `unsafe.Pointer(&T)`
975+
// - The pointer to a map header, e.g. `*(*unsafe.Pointer)(&map[K]V)`
976+
//
977+
// Many [encoder] methods accept a pointer-to-a-pointer,
978+
// and so those may need to be derenced in order to safely pass them here.
979+
func checkRefCycle(e *encoder, t reflect.Type, p unsafe.Pointer) error {
980+
e.ptrDepth++
981+
if e.ptrDepth < startDetectingCyclesAfter {
982+
return nil
983+
}
984+
985+
_, seen := e.ptrSeen[p]
986+
if seen {
987+
v := reflect.NewAt(t, p)
988+
return &UnsupportedValueError{
989+
Value: v,
990+
Str: fmt.Sprintf("encountered a cycle via %s", t),
991+
}
992+
}
993+
994+
if e.ptrSeen == nil {
995+
e.ptrSeen = cycleMapPool.Get().(cycleMap)
996+
}
997+
998+
e.ptrSeen[p] = struct{}{}
999+
1000+
return nil
1001+
}
1002+
1003+
// freeRefCycle performs the cleanup operation for [checkRefCycle].
1004+
// p must be the same value passed into a prior call to checkRefCycle.
1005+
func freeRefCycleInfo(e *encoder, p unsafe.Pointer) {
1006+
if e.ptrSeen == nil {
1007+
// The map hasn't yet been allocated (not enough recursion depth),
1008+
// so there's not any work to do in this function.
1009+
return
1010+
}
1011+
1012+
delete(e.ptrSeen, p)
1013+
if len(e.ptrSeen) != 0 {
1014+
// There are other keys in the map, so we can't release it into the pool.
1015+
return
1016+
}
1017+
1018+
m := e.ptrSeen
1019+
e.ptrSeen = nil
1020+
cycleMapPool.Put(m)
1021+
}

0 commit comments

Comments
 (0)