Skip to content

Commit acbd0c9

Browse files
committed
Refactored out dcite call to avoid defining __init__.
1 parent c2606c0 commit acbd0c9

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

src/qinfer/test_models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656
import numpy as np
5757

58-
from .utils import binomial_pdf
58+
from .utils import binomial_pdf, decorate_init
5959
from .abstract_model import FiniteOutcomeModel, DifferentiableModel
6060
from ._due import due, Doi
6161

@@ -212,6 +212,13 @@ def score(self, outcomes, modelparams, expparams, return_L=False):
212212
else:
213213
return q
214214

215+
@decorate_init(
216+
due.dcite(
217+
Doi('10.1088/1367-2630/14/10/103013'),
218+
description='Robust online Hamiltonian learning',
219+
tags=['implementation']
220+
)
221+
)
215222
class UnknownT2Model(FiniteOutcomeModel):
216223
"""
217224
Describes the free evolution of a single qubit prepared in the
@@ -223,11 +230,6 @@ class UnknownT2Model(FiniteOutcomeModel):
223230
:modelparam T2_inv: The decoherence strength :math:`T_2^{-1}`.
224231
:scalar-expparam float: The evolution time :math:`t`.
225232
"""
226-
@due.dcite(
227-
Doi('10.1088/1367-2630/14/10/103013'),
228-
description='Robust online Hamiltonian learning',
229-
tags=['implementation']
230-
)
231233

232234
@property
233235
def n_modelparams(self): return 2

src/qinfer/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,20 @@ def sqrtm_psd(A, est_error=True, check_finite=True):
606606
else:
607607
return A_sqrt
608608

609+
def decorate_init(init_decorator):
610+
"""
611+
Given a class definition and a decorator that acts on methods,
612+
applies that decorator to the class' __init__ method.
613+
Useful for decorating __init__ while still allowing __init__ to be
614+
inherited.
615+
"""
616+
617+
def class_decorator(cls):
618+
cls.__init__ = init_decorator(cls.__init__)
619+
return cls
620+
621+
return class_decorator
622+
609623
#==============================================================================
610624
#Test Code
611625
if __name__ == "__main__":

0 commit comments

Comments
 (0)