@@ -139,17 +139,27 @@ func (f fieldset) index(i int) (int, int) {
139139// ParseRewriteTemplate constructs a Rewriter for a protobuf type using the
140140// given json template to describe the rewrite rules.
141141//
142- // The json template contains a representation of the
143- func ParseRewriteTemplate (typ Type , jsonTemplate []byte ) (Rewriter , error ) {
142+ // The json template contains a representation of the message that is used as the
143+ // source values to overwrite in the protobuf targeted by the resulting rewriter.
144+ //
145+ // The rules are an optional set of RewriterRules that can provide alternative
146+ // Rewriters from the default used for the field type. These rules are given the
147+ // json.RawMessage bytes from the template, and they are expected to create a
148+ // Rewriter to be applied against the target protobuf.
149+ func ParseRewriteTemplate (typ Type , jsonTemplate []byte , rules ... RewriterRules ) (Rewriter , error ) {
144150 switch typ .Kind () {
145151 case Struct :
146- return parseRewriteTemplateStruct (typ , 0 , jsonTemplate )
152+ return parseRewriteTemplateStruct (typ , 0 , jsonTemplate , rules ... )
147153 default :
148154 return nil , fmt .Errorf ("cannot construct a rewrite template from a non-struct type %s" , typ .Name ())
149155 }
150156}
151157
152- func parseRewriteTemplate (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
158+ func parseRewriteTemplate (t Type , f FieldNumber , j json.RawMessage , rule any ) (Rewriter , error ) {
159+ if rwer , ok := rule .(Rewriterer ); ok {
160+ return rwer .Rewriter (t , f , j )
161+ }
162+
153163 switch t .Kind () {
154164 case Bool :
155165 return parseRewriteTemplateBool (t , f , j )
@@ -184,7 +194,11 @@ func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, e
184194 case Map :
185195 return parseRewriteTemplateMap (t , f , j )
186196 case Struct :
187- return parseRewriteTemplateStruct (t , f , j )
197+ sub , n , ok := [1 ]RewriterRules {}, 0 , false
198+ if sub [0 ], ok = rule .(RewriterRules ); ok {
199+ n = 1
200+ }
201+ return parseRewriteTemplateStruct (t , f , j , sub [:n ]... )
188202 default :
189203 return nil , fmt .Errorf ("cannot construct a rewriter from type %s" , t .Name ())
190204 }
@@ -376,7 +390,7 @@ func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter
376390 return MultiRewriter (rewriters ... ), nil
377391}
378392
379- func parseRewriteTemplateStruct (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
393+ func parseRewriteTemplateStruct (t Type , f FieldNumber , j json.RawMessage , rules ... RewriterRules ) (Rewriter , error ) {
380394 template := map [string ]json.RawMessage {}
381395
382396 if err := json .Unmarshal (j , & template ); err != nil {
@@ -408,10 +422,18 @@ func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewri
408422 fields = []json.RawMessage {v }
409423 }
410424
425+ var rule any
426+ for i := range rules {
427+ if r , ok := rules [i ][f.Name ]; ok {
428+ rule = r
429+ break
430+ }
431+ }
432+
411433 rewriters = rewriters [:0 ]
412434
413435 for _ , v := range fields {
414- rw , err := parseRewriteTemplate (f .Type , f .Number , v )
436+ rw , err := parseRewriteTemplate (f .Type , f .Number , v , rule )
415437 if err != nil {
416438 return nil , fmt .Errorf ("%s: %w" , k , err )
417439 }
@@ -462,3 +484,117 @@ func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) {
462484 copy (out [prefix :], b [:tagAndLen ])
463485 return out , nil
464486}
487+
488+ // RewriterRules defines a set of rules for overriding the Rewriter used for any
489+ // particular field. These maps may be nested for defining rules for struct members.
490+ //
491+ // For example:
492+ //
493+ // rules := proto.RewriterRules {
494+ // "flags": proto.BitOr[uint64]{},
495+ // "nested": proto.RewriterRules {
496+ // "name": myCustomRewriter,
497+ // },
498+ // }
499+ type RewriterRules map [string ]any
500+
501+ // Rewriterer is the interface for producing a Rewriter for a given Type, FieldNumber
502+ // and json.RawMessage. The JSON value is the JSON-encoded payload that should be
503+ // decoded to produce the appropriate Rewriter. Implementations of the Rewriterer
504+ // interface are added to the RewriterRules to specify the rules for performing
505+ // custom rewrite logic.
506+ type Rewriterer interface {
507+ Rewriter (Type , FieldNumber , json.RawMessage ) (Rewriter , error )
508+ }
509+
510+ // BitOr implments the Rewriterer interface for providing a bitwise-or rewrite
511+ // logic for integers rather than replacing them. Instances of this type are
512+ // zero-size, carrying only the generic type for creating the appropriate
513+ // Rewriter when requested.
514+ //
515+ // Adding these to a RewriterRules looks like:
516+ //
517+ // rules := proto.RewriterRules {
518+ // "flags": proto.BitOr[uint64]{},
519+ // }
520+ //
521+ // When used as a rule when rewriting from a template, the BitOr expects a JSON-
522+ // encoded integer passed into the Rewriter method. This parsed integer is then
523+ // used to perform a bitwise-or against the protobuf message that is being rewritten.
524+ //
525+ // The above example can then be used like:
526+ //
527+ // template := []byte(`{"flags": 8}`) // n |= 0b1000
528+ // rw, err := proto.ParseRewriteTemplate(typ, template, rules)
529+ type BitOr [T integer ] struct {}
530+
531+ // integer is the contraint used by the BitOr Rewriterer and the bitOrRW Rewriter.
532+ // Because these perform bitwise-or operations, the types must be integer-like.
533+ type integer interface {
534+ ~ int | ~ int32 | ~ int64 | ~ uint | ~ uint32 | ~ uint64
535+ }
536+
537+ // Rewriter implements the Rewriterer interface. The JSON value provided to this
538+ // method comes from the template used for rewriting. The returned Rewriter will use
539+ // this JSON-encoded integer to perform a bitwise-or against the protobuf message
540+ // that is being rewritten.
541+ func (BitOr [T ]) Rewriter (t Type , f FieldNumber , j json.RawMessage ) (Rewriter , error ) {
542+ var v T
543+ err := json .Unmarshal (j , & v )
544+ if err != nil {
545+ return nil , err
546+ }
547+ return BitOrRewriter (t , f , v )
548+ }
549+
550+ // BitOrRewriter creates a bitwise-or Rewriter for a given field type and number.
551+ // The mask is the value or'ed with values in the target protobuf.
552+ func BitOrRewriter [T integer ](t Type , f FieldNumber , mask T ) (Rewriter , error ) {
553+ switch t .Kind () {
554+ case Int32 , Int64 , Sint32 , Sint64 , Uint32 , Uint64 , Fix32 , Fix64 , Sfix32 , Sfix64 :
555+ default :
556+ return nil , fmt .Errorf ("cannot construct a rewriter from type %s" , t .Name ())
557+ }
558+ return bitOrRW [T ]{mask : mask , t : t , f : f }, nil
559+ }
560+
561+ // bitOrRW is the Rewriter returned by the BitOr Rewriter method.
562+ type bitOrRW [T integer ] struct {
563+ mask T
564+ t Type
565+ f FieldNumber
566+ }
567+
568+ // Rewrite implements the Rewriter interface performing a bitwise-or between the
569+ // template value and the input value.
570+ func (r bitOrRW [T ]) Rewrite (out , in []byte ) ([]byte , error ) {
571+ var v T
572+ if err := Unmarshal (in , & v ); err != nil {
573+ return nil , err
574+ }
575+
576+ v |= r .mask
577+
578+ switch r .t .Kind () {
579+ case Int32 :
580+ return r .f .Int32 (int32 (v )).Rewrite (out , in )
581+ case Int64 :
582+ return r .f .Int64 (int64 (v )).Rewrite (out , in )
583+ case Sint32 :
584+ return r .f .Uint32 (encodeZigZag32 (int32 (v ))).Rewrite (out , in )
585+ case Sint64 :
586+ return r .f .Uint64 (encodeZigZag64 (int64 (v ))).Rewrite (out , in )
587+ case Uint32 , Uint64 :
588+ return r .f .Uint64 (uint64 (v )).Rewrite (out , in )
589+ case Fix32 :
590+ return r .f .Fixed32 (uint32 (v )).Rewrite (out , in )
591+ case Fix64 :
592+ return r .f .Fixed64 (uint64 (v )).Rewrite (out , in )
593+ case Sfix32 :
594+ return r .f .Fixed32 (encodeZigZag32 (int32 (v ))).Rewrite (out , in )
595+ case Sfix64 :
596+ return r .f .Fixed64 (encodeZigZag64 (int64 (v ))).Rewrite (out , in )
597+ }
598+
599+ panic ("unreachable" ) // Kind is validated when creating instances
600+ }
0 commit comments