Skip to content

Commit d71677f

Browse files
authored
refactor: consolidate once-only registration of extras (#85)
## Why this should be merged Consolidates duplicated logic. Similar rationale to #84. ## How this works New `register.AtMostOnce[T]` type is responsible for limiting calls to `Register()`. ## How this was tested Existing unit tests of `params`. Note that the equivalent functionality in `types` wasn't tested but now is.
1 parent 25e5ca3 commit d71677f

File tree

6 files changed

+111
-48
lines changed

6 files changed

+111
-48
lines changed

core/types/rlp_payload.libevm.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"io"
2222

2323
"github.com/ava-labs/libevm/libevm/pseudo"
24+
"github.com/ava-labs/libevm/libevm/register"
2425
"github.com/ava-labs/libevm/libevm/testonly"
2526
"github.com/ava-labs/libevm/rlp"
2627
)
@@ -37,18 +38,15 @@ import (
3738
// The payload can be accessed via the [ExtraPayloads.FromPayloadCarrier] method
3839
// of the accessor returned by RegisterExtras.
3940
func RegisterExtras[SA any]() ExtraPayloads[SA] {
40-
if registeredExtras != nil {
41-
panic("re-registration of Extras")
42-
}
4341
var extra ExtraPayloads[SA]
44-
registeredExtras = &extraConstructors{
42+
registeredExtras.MustRegister(&extraConstructors{
4543
stateAccountType: func() string {
4644
var x SA
4745
return fmt.Sprintf("%T", x)
4846
}(),
4947
newStateAccount: pseudo.NewConstructor[SA]().Zero,
5048
cloneStateAccount: extra.cloneStateAccount,
51-
}
49+
})
5250
return extra
5351
}
5452

@@ -59,12 +57,10 @@ func RegisterExtras[SA any]() ExtraPayloads[SA] {
5957
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
6058
// a workaround for the single-call limitation on [RegisterExtras].
6159
func TestOnlyClearRegisteredExtras() {
62-
testonly.OrPanic(func() {
63-
registeredExtras = nil
64-
})
60+
registeredExtras.TestOnlyClear()
6561
}
6662

67-
var registeredExtras *extraConstructors
63+
var registeredExtras register.AtMostOnce[*extraConstructors]
6864

6965
type extraConstructors struct {
7066
stateAccountType string
@@ -74,10 +70,10 @@ type extraConstructors struct {
7470

7571
func (e *StateAccountExtra) clone() *StateAccountExtra {
7672
switch r := registeredExtras; {
77-
case r == nil, e == nil:
73+
case !r.Registered(), e == nil:
7874
return nil
7975
default:
80-
return r.cloneStateAccount(e)
76+
return r.Get().cloneStateAccount(e)
8177
}
8278
}
8379

@@ -146,15 +142,15 @@ func (a *SlimAccount) extra() *StateAccountExtra {
146142
func getOrSetNewStateAccountExtra(curr **StateAccountExtra) *StateAccountExtra {
147143
if *curr == nil {
148144
*curr = &StateAccountExtra{
149-
t: registeredExtras.newStateAccount(),
145+
t: registeredExtras.Get().newStateAccount(),
150146
}
151147
}
152148
return *curr
153149
}
154150

155151
func (e *StateAccountExtra) payload() *pseudo.Type {
156152
if e.t == nil {
157-
e.t = registeredExtras.newStateAccount()
153+
e.t = registeredExtras.Get().newStateAccount()
158154
}
159155
return e.t
160156
}
@@ -196,24 +192,24 @@ var _ interface {
196192
// EncodeRLP implements the [rlp.Encoder] interface.
197193
func (e *StateAccountExtra) EncodeRLP(w io.Writer) error {
198194
switch r := registeredExtras; {
199-
case r == nil:
195+
case !r.Registered():
200196
return nil
201197
case e == nil:
202198
e = &StateAccountExtra{}
203199
fallthrough
204200
case e.t == nil:
205-
e.t = r.newStateAccount()
201+
e.t = r.Get().newStateAccount()
206202
}
207203
return e.t.EncodeRLP(w)
208204
}
209205

210206
// DecodeRLP implements the [rlp.Decoder] interface.
211207
func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
212208
switch r := registeredExtras; {
213-
case r == nil:
209+
case !r.Registered():
214210
return nil
215211
case e.t == nil:
216-
e.t = r.newStateAccount()
212+
e.t = r.Get().newStateAccount()
217213
fallthrough
218214
default:
219215
return s.Decode(e.t)
@@ -224,10 +220,10 @@ func (e *StateAccountExtra) DecodeRLP(s *rlp.Stream) error {
224220
func (e *StateAccountExtra) Format(s fmt.State, verb rune) {
225221
var out string
226222
switch r := registeredExtras; {
227-
case r == nil:
223+
case !r.Registered():
228224
out = "<nil>"
229225
case e == nil, e.t == nil:
230-
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.stateAccountType)
226+
out = fmt.Sprintf("<nil>[*StateAccountExtra[%s]]", r.Get().stateAccountType)
231227
default:
232228
e.t.Format(s, verb)
233229
return

libevm/register/register.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2024 the libevm authors.
2+
//
3+
// The libevm additions to go-ethereum are free software: you can redistribute
4+
// them and/or modify them under the terms of the GNU Lesser General Public License
5+
// as published by the Free Software Foundation, either version 3 of the License,
6+
// or (at your option) any later version.
7+
//
8+
// The libevm additions are distributed in the hope that they will be useful,
9+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
11+
// General Public License for more details.
12+
//
13+
// You should have received a copy of the GNU Lesser General Public License
14+
// along with the go-ethereum library. If not, see
15+
// <http://www.gnu.org/licenses/>.
16+
17+
// Package register provides functionality for optional registration of types.
18+
package register
19+
20+
import (
21+
"errors"
22+
23+
"github.com/ava-labs/libevm/libevm/testonly"
24+
)
25+
26+
// An AtMostOnce allows zero or one registration of a T.
27+
type AtMostOnce[T any] struct {
28+
v *T
29+
}
30+
31+
// ErrReRegistration is returned on all but the first of calls to
32+
// [AtMostOnce.Register].
33+
var ErrReRegistration = errors.New("re-registration")
34+
35+
// Register registers `v` or returns [ErrReRegistration] if already called.
36+
func (o *AtMostOnce[T]) Register(v T) error {
37+
if o.Registered() {
38+
return ErrReRegistration
39+
}
40+
o.v = &v
41+
return nil
42+
}
43+
44+
// MustRegister is equivalent to [AtMostOnce.Register], panicking on error.
45+
func (o *AtMostOnce[T]) MustRegister(v T) {
46+
if err := o.Register(v); err != nil {
47+
panic(err)
48+
}
49+
}
50+
51+
// Registered reports whether [AtMostOnce.Register] has been called.
52+
func (o *AtMostOnce[T]) Registered() bool {
53+
return o.v != nil
54+
}
55+
56+
// Get returns the registered value. It MUST NOT be called before
57+
// [AtMostOnce.Register].
58+
func (o *AtMostOnce[T]) Get() T {
59+
return *o.v
60+
}
61+
62+
// TestOnlyClear clears any previously registered value, returning `o` to its
63+
// default state. It panics if called from a non-testing call stack.
64+
func (o *AtMostOnce[T]) TestOnlyClear() {
65+
testonly.OrPanic(func() {
66+
o.v = nil
67+
})
68+
}

params/config.libevm.go

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
"reflect"
2323

2424
"github.com/ava-labs/libevm/libevm/pseudo"
25-
"github.com/ava-labs/libevm/libevm/testonly"
25+
"github.com/ava-labs/libevm/libevm/register"
2626
)
2727

2828
// Extras are arbitrary payloads to be added as extra fields in [ChainConfig]
@@ -68,20 +68,17 @@ type Extras[C ChainConfigHooks, R RulesHooks] struct {
6868
// alter Ethereum behaviour; if this isn't desired then they can embed
6969
// [NOOPHooks] to satisfy either interface.
7070
func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPayloads[C, R] {
71-
if registeredExtras != nil {
72-
panic("re-registration of Extras")
73-
}
7471
mustBeStructOrPointerToOne[C]()
7572
mustBeStructOrPointerToOne[R]()
7673

7774
payloads := e.payloads()
78-
registeredExtras = &extraConstructors{
75+
registeredExtras.MustRegister(&extraConstructors{
7976
newChainConfig: pseudo.NewConstructor[C]().Zero,
8077
newRules: pseudo.NewConstructor[R]().Zero,
8178
reuseJSONRoot: e.ReuseJSONRoot,
8279
newForRules: e.newForRules,
8380
payloads: payloads,
84-
}
81+
})
8582
return payloads
8683
}
8784

@@ -92,14 +89,12 @@ func RegisterExtras[C ChainConfigHooks, R RulesHooks](e Extras[C, R]) ExtraPaylo
9289
// defer-called afterwards, either directly or via testing.TB.Cleanup(). This is
9390
// a workaround for the single-call limitation on [RegisterExtras].
9491
func TestOnlyClearRegisteredExtras() {
95-
testonly.OrPanic(func() {
96-
registeredExtras = nil
97-
})
92+
registeredExtras.TestOnlyClear()
9893
}
9994

10095
// registeredExtras holds non-generic constructors for the [Extras] types
10196
// registered via [RegisterExtras].
102-
var registeredExtras *extraConstructors
97+
var registeredExtras register.AtMostOnce[*extraConstructors]
10398

10499
type extraConstructors struct {
105100
newChainConfig, newRules func() *pseudo.Type
@@ -115,7 +110,7 @@ type extraConstructors struct {
115110

116111
func (e *Extras[C, R]) newForRules(c *ChainConfig, r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) *pseudo.Type {
117112
if e.NewRules == nil {
118-
return registeredExtras.newRules()
113+
return registeredExtras.Get().newRules()
119114
}
120115
rExtra := e.NewRules(c, r, e.payloads().FromChainConfig(c), blockNum, isMerge, timestamp)
121116
return pseudo.From(rExtra).Type
@@ -209,36 +204,36 @@ func (e ExtraPayloads[C, R]) hooksFromRules(r *Rules) RulesHooks {
209204
// abstract the libevm-specific behaviour outside of original geth code.
210205
func (c *ChainConfig) addRulesExtra(r *Rules, blockNum *big.Int, isMerge bool, timestamp uint64) {
211206
r.extra = nil
212-
if registeredExtras != nil {
213-
r.extra = registeredExtras.newForRules(c, r, blockNum, isMerge, timestamp)
207+
if registeredExtras.Registered() {
208+
r.extra = registeredExtras.Get().newForRules(c, r, blockNum, isMerge, timestamp)
214209
}
215210
}
216211

217212
// extraPayload returns the ChainConfig's extra payload iff [RegisterExtras] has
218213
// already been called. If the payload hasn't been populated (typically via
219214
// unmarshalling of JSON), a nil value is constructed and returned.
220215
func (c *ChainConfig) extraPayload() *pseudo.Type {
221-
if registeredExtras == nil {
216+
if !registeredExtras.Registered() {
222217
// This will only happen if someone constructs an [ExtraPayloads]
223218
// directly, without a call to [RegisterExtras].
224219
//
225220
// See https://google.github.io/styleguide/go/best-practices#when-to-panic
226221
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", c))
227222
}
228223
if c.extra == nil {
229-
c.extra = registeredExtras.newChainConfig()
224+
c.extra = registeredExtras.Get().newChainConfig()
230225
}
231226
return c.extra
232227
}
233228

234229
// extraPayload is equivalent to [ChainConfig.extraPayload].
235230
func (r *Rules) extraPayload() *pseudo.Type {
236-
if registeredExtras == nil {
231+
if !registeredExtras.Registered() {
237232
// See ChainConfig.extraPayload() equivalent.
238233
panic(fmt.Sprintf("%T.ExtraPayload() called before RegisterExtras()", r))
239234
}
240235
if r.extra == nil {
241-
r.extra = registeredExtras.newRules()
236+
r.extra = registeredExtras.Get().newRules()
242237
}
243238
return r.extra
244239
}

params/config.libevm_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/stretchr/testify/require"
2525

2626
"github.com/ava-labs/libevm/libevm/pseudo"
27+
"github.com/ava-labs/libevm/libevm/register"
2728
)
2829

2930
type rawJSON struct {
@@ -255,18 +256,21 @@ func TestExtrasPanic(t *testing.T) {
255256
t, func() {
256257
RegisterExtras(Extras[struct{ ChainConfigHooks }, struct{ RulesHooks }]{})
257258
},
258-
"re-registration",
259+
register.ErrReRegistration.Error(),
259260
)
260261
}
261262

262263
func assertPanics(t *testing.T, fn func(), wantContains string) {
263264
t.Helper()
264265
defer func() {
266+
t.Helper()
265267
switch r := recover().(type) {
266268
case nil:
267-
t.Error("function did not panic as expected")
269+
t.Error("function did not panic when panic expected")
268270
case string:
269271
assert.Contains(t, r, wantContains)
272+
case error:
273+
assert.Contains(t, r.Error(), wantContains)
270274
default:
271275
t.Fatalf("BAD TEST SETUP: recover() got unsupported type %T", r)
272276
}

params/hooks.libevm.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ type RulesAllowlistHooks interface {
6969
// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
7070
// none were registered.
7171
func (c *ChainConfig) Hooks() ChainConfigHooks {
72-
if e := registeredExtras; e != nil {
73-
return e.payloads.hooksFromChainConfig(c)
72+
if e := registeredExtras; e.Registered() {
73+
return e.Get().payloads.hooksFromChainConfig(c)
7474
}
7575
return NOOPHooks{}
7676
}
7777

7878
// Hooks returns the hooks registered with [RegisterExtras], or [NOOPHooks] if
7979
// none were registered.
8080
func (r *Rules) Hooks() RulesHooks {
81-
if e := registeredExtras; e != nil {
82-
return e.payloads.hooksFromRules(r)
81+
if e := registeredExtras; e.Registered() {
82+
return e.Get().payloads.hooksFromRules(r)
8383
}
8484
return NOOPHooks{}
8585
}

params/json.libevm.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ type chainConfigWithExportedExtra struct {
4242
// UnmarshalJSON implements the [json.Unmarshaler] interface.
4343
func (c *ChainConfig) UnmarshalJSON(data []byte) error {
4444
switch reg := registeredExtras; {
45-
case reg != nil && !reg.reuseJSONRoot:
45+
case reg.Registered() && !reg.Get().reuseJSONRoot:
4646
return c.unmarshalJSONWithExtra(data)
4747

48-
case reg != nil && reg.reuseJSONRoot: // although the latter is redundant, it's clearer
49-
c.extra = reg.newChainConfig()
48+
case reg.Registered() && reg.Get().reuseJSONRoot: // although the latter is redundant, it's clearer
49+
c.extra = reg.Get().newChainConfig()
5050
if err := json.Unmarshal(data, c.extra); err != nil {
5151
c.extra = nil
5252
return err
@@ -63,7 +63,7 @@ func (c *ChainConfig) UnmarshalJSON(data []byte) error {
6363
func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
6464
cc := &chainConfigWithExportedExtra{
6565
chainConfigWithoutMethods: (*chainConfigWithoutMethods)(c),
66-
Extra: registeredExtras.newChainConfig(),
66+
Extra: registeredExtras.Get().newChainConfig(),
6767
}
6868
if err := json.Unmarshal(data, cc); err != nil {
6969
return err
@@ -75,10 +75,10 @@ func (c *ChainConfig) unmarshalJSONWithExtra(data []byte) error {
7575
// MarshalJSON implements the [json.Marshaler] interface.
7676
func (c *ChainConfig) MarshalJSON() ([]byte, error) {
7777
switch reg := registeredExtras; {
78-
case reg == nil:
78+
case !reg.Registered():
7979
return json.Marshal((*chainConfigWithoutMethods)(c))
8080

81-
case !reg.reuseJSONRoot:
81+
case !reg.Get().reuseJSONRoot:
8282
return c.marshalJSONWithExtra()
8383

8484
default: // reg.reuseJSONRoot == true

0 commit comments

Comments
 (0)