@@ -40,8 +40,39 @@ type (
4040 ProcessPipelineHook func (ctx context.Context , cmds []Cmder ) error
4141)
4242
43+ var (
44+ nonDialHook = func (ctx context.Context , network , addr string ) (net.Conn , error ) { return nil , nil }
45+ nonProcessHook = func (ctx context.Context , cmd Cmder ) error { return nil }
46+ nonProcessPipelineHook = func (ctx context.Context , cmds []Cmder ) error { return nil }
47+ nonTxProcessPipelineHook = func (ctx context.Context , cmds []Cmder ) error { return nil }
48+ )
49+
50+ type defaultHook struct {
51+ dial DialHook
52+ process ProcessHook
53+ pipeline ProcessPipelineHook
54+ txPipeline ProcessPipelineHook
55+ }
56+
57+ func (h * defaultHook ) init () {
58+ if h .dial == nil {
59+ h .dial = nonDialHook
60+ }
61+ if h .process == nil {
62+ h .process = nonProcessHook
63+ }
64+ if h .pipeline == nil {
65+ h .pipeline = nonProcessPipelineHook
66+ }
67+ if h .txPipeline == nil {
68+ h .txPipeline = nonTxProcessPipelineHook
69+ }
70+ }
71+
4372type hooks struct {
44- slice []Hook
73+ slice []Hook
74+ defaultHook defaultHook
75+
4576 dialHook DialHook
4677 processHook ProcessHook
4778 processPipelineHook ProcessPipelineHook
@@ -87,55 +118,45 @@ type hooks struct {
87118// if "next(ctx, cmd)" is not executed in hook-1, the redis command will not be executed.
88119func (hs * hooks ) AddHook (hook Hook ) {
89120 hs .slice = append (hs .slice , hook )
90- hs .dialHook = hook .DialHook (hs .dialHook )
91- hs .processHook = hook .ProcessHook (hs .processHook )
92- hs .processPipelineHook = hook .ProcessPipelineHook (hs .processPipelineHook )
93- hs .processTxPipelineHook = hook .ProcessPipelineHook (hs .processTxPipelineHook )
121+ hs .chain ()
94122}
95123
96- func (hs * hooks ) clone () hooks {
97- clone := * hs
98- l := len (clone .slice )
99- clone .slice = clone .slice [:l :l ]
100- return clone
101- }
124+ func (hs * hooks ) chain () {
125+ hs .defaultHook .init ()
102126
103- func (hs * hooks ) setDial (dial DialHook ) {
104- hs .dialHook = dial
105- for _ , h := range hs .slice {
106- if wrapped := h .DialHook (hs .dialHook ); wrapped != nil {
127+ hs .dialHook = hs .defaultHook .dial
128+ hs .processHook = hs .defaultHook .process
129+ hs .processPipelineHook = hs .defaultHook .pipeline
130+ hs .processTxPipelineHook = hs .defaultHook .txPipeline
131+
132+ for i := len (hs .slice ) - 1 ; i >= 0 ; i -- {
133+ if wrapped := hs .slice [i ].DialHook (hs .dialHook ); wrapped != nil {
107134 hs .dialHook = wrapped
108135 }
109- }
110- }
111-
112- func (hs * hooks ) setProcess (process ProcessHook ) {
113- hs .processHook = process
114- for _ , h := range hs .slice {
115- if wrapped := h .ProcessHook (hs .processHook ); wrapped != nil {
136+ if wrapped := hs .slice [i ].ProcessHook (hs .processHook ); wrapped != nil {
116137 hs .processHook = wrapped
117138 }
118- }
119- }
120-
121- func (hs * hooks ) setProcessPipeline (processPipeline ProcessPipelineHook ) {
122- hs .processPipelineHook = processPipeline
123- for _ , h := range hs .slice {
124- if wrapped := h .ProcessPipelineHook (hs .processPipelineHook ); wrapped != nil {
139+ if wrapped := hs .slice [i ].ProcessPipelineHook (hs .processPipelineHook ); wrapped != nil {
125140 hs .processPipelineHook = wrapped
126141 }
127- }
128- }
129-
130- func (hs * hooks ) setProcessTxPipeline (processTxPipeline ProcessPipelineHook ) {
131- hs .processTxPipelineHook = processTxPipeline
132- for _ , h := range hs .slice {
133- if wrapped := h .ProcessPipelineHook (hs .processTxPipelineHook ); wrapped != nil {
142+ if wrapped := hs .slice [i ].ProcessPipelineHook (hs .processTxPipelineHook ); wrapped != nil {
134143 hs .processTxPipelineHook = wrapped
135144 }
136145 }
137146}
138147
148+ func (hs * hooks ) clone () hooks {
149+ clone := * hs
150+ l := len (clone .slice )
151+ clone .slice = clone .slice [:l :l ]
152+ return clone
153+ }
154+
155+ func (hs * hooks ) setDefaultHook (d defaultHook ) {
156+ hs .defaultHook = d
157+ hs .chain ()
158+ }
159+
139160func (hs * hooks ) withProcessHook (ctx context.Context , cmd Cmder , hook ProcessHook ) error {
140161 for _ , h := range hs .slice {
141162 if wrapped := h .ProcessHook (hook ); wrapped != nil {
@@ -595,10 +616,12 @@ func NewClient(opt *Options) *Client {
595616
596617func (c * Client ) init () {
597618 c .cmdable = c .Process
598- c .hooks .setDial (c .baseClient .dial )
599- c .hooks .setProcess (c .baseClient .process )
600- c .hooks .setProcessPipeline (c .baseClient .processPipeline )
601- c .hooks .setProcessTxPipeline (c .baseClient .processTxPipeline )
619+ c .hooks .setDefaultHook (defaultHook {
620+ dial : c .baseClient .dial ,
621+ process : c .baseClient .process ,
622+ pipeline : c .baseClient .processPipeline ,
623+ txPipeline : c .baseClient .processTxPipeline ,
624+ })
602625}
603626
604627func (c * Client ) WithTimeout (timeout time.Duration ) * Client {
@@ -755,11 +778,12 @@ func newConn(opt *Options, connPool pool.Pooler) *Conn {
755778
756779 c .cmdable = c .Process
757780 c .statefulCmdable = c .Process
758-
759- c .hooks .setDial (c .baseClient .dial )
760- c .hooks .setProcess (c .baseClient .process )
761- c .hooks .setProcessPipeline (c .baseClient .processPipeline )
762- c .hooks .setProcessTxPipeline (c .baseClient .processTxPipeline )
781+ c .hooks .setDefaultHook (defaultHook {
782+ dial : c .baseClient .dial ,
783+ process : c .baseClient .process ,
784+ pipeline : c .baseClient .processPipeline ,
785+ txPipeline : c .baseClient .processTxPipeline ,
786+ })
763787
764788 return & c
765789}
0 commit comments