|
1 | 1 | from abc import abstractmethod |
2 | 2 |
|
3 | | -import numba |
4 | | -from numba import float32, float64 |
5 | | -from numba.experimental import jitclass |
6 | | - |
7 | | - |
8 | | -def spec_to_float32(spec): |
9 | | - """Convert a numba specification to an equivalent float32 one. |
10 | | -
|
11 | | - Parameters |
12 | | - ---------- |
13 | | - spec : list |
14 | | - A list of (name, dtype) for every attribute of a jitclass. |
15 | | -
|
16 | | - Returns |
17 | | - ------- |
18 | | - spec32 : list |
19 | | - A list of (name, dtype) for every attribute of a jitclass, where float64 |
20 | | - have been replaced by float32. |
21 | | - """ |
22 | | - spec32 = [] |
23 | | - for name, dtype in spec: |
24 | | - if dtype == float64: |
25 | | - dtype32 = float32 |
26 | | - elif isinstance(dtype, numba.core.types.npytypes.Array): |
27 | | - dtype32 = dtype.copy(dtype=float32) |
28 | | - else: |
29 | | - raise ValueError(f"Unknown spec type {dtype}") |
30 | | - spec32.append((name, dtype32)) |
31 | | - return spec32 |
32 | | - |
33 | | - |
34 | | -def jit_factory(Datafit, spec): |
35 | | - """JIT-compile a datafit class in float32 and float64 contexts. |
36 | | -
|
37 | | - Parameters |
38 | | - ---------- |
39 | | - Datafit : datafit class, inheriting from BaseDatafit |
40 | | - A datafit class, to be compiled. |
41 | | -
|
42 | | - spec : list |
43 | | - A list of type specifications for every attribute of Datafit. |
44 | | -
|
45 | | - Returns |
46 | | - ------- |
47 | | - Datafit_64 : Jitclass |
48 | | - A compiled datafit class with attribute types float64. |
49 | | -
|
50 | | - Datafit_32 : Jitclass |
51 | | - A compiled datafit class with attribute types float32. |
52 | | - """ |
53 | | - spec32 = spec_to_float32(spec) |
54 | | - return jitclass(spec)(Datafit), jitclass(spec32)(Datafit) |
55 | | - |
56 | 3 |
|
57 | 4 | class BaseDatafit(): |
58 | 5 | """Base class for datafits.""" |
59 | 6 |
|
| 7 | + @abstractmethod |
| 8 | + def get_spec(self): |
| 9 | + """Specify the numba types of the class attributes. |
| 10 | +
|
| 11 | + Returns |
| 12 | + ------- |
| 13 | + spec: Tuple of (attribute_name, dtype) |
| 14 | + spec to be passed to Numba jitclass to compile the class. |
| 15 | + """ |
| 16 | + |
| 17 | + @abstractmethod |
| 18 | + def params_to_dict(self): |
| 19 | + """Get the parameters to initialize an instance of the class. |
| 20 | +
|
| 21 | + Returns |
| 22 | + ------- |
| 23 | + dict_of_params : dict |
| 24 | + The parameters to instantiate an object of the class. |
| 25 | + """ |
| 26 | + |
60 | 27 | @abstractmethod |
61 | 28 | def initialize(self, X, y): |
62 | 29 | """Pre-computations before fitting on X and y. |
@@ -172,6 +139,26 @@ def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): |
172 | 139 | class BaseMultitaskDatafit(): |
173 | 140 | """Base class for multitask datafits.""" |
174 | 141 |
|
| 142 | + @abstractmethod |
| 143 | + def get_spec(self): |
| 144 | + """Specify the numba types of the class attributes. |
| 145 | +
|
| 146 | + Returns |
| 147 | + ------- |
| 148 | + spec: Tuple of (attribute_name, dtype) |
| 149 | + spec to be passed to Numba jitclass to compile the class. |
| 150 | + """ |
| 151 | + |
| 152 | + @abstractmethod |
| 153 | + def params_to_dict(self): |
| 154 | + """Get the parameters to initialize an instance of the class. |
| 155 | +
|
| 156 | + Returns |
| 157 | + ------- |
| 158 | + dict_of_params : dict |
| 159 | + The parameters to instantiate an object of the class. |
| 160 | + """ |
| 161 | + |
175 | 162 | @abstractmethod |
176 | 163 | def initialize(self, X, Y): |
177 | 164 | """Store useful values before fitting on X and Y. |
|
0 commit comments