@@ -50,26 +50,45 @@ def apply_ufunc(
5050 if method not in __METHODS_FUNC__ :
5151 raise UnknownMethodError (method , __METHODS_FUNC__ .keys ())
5252
53+ if kwargs .get ("input_core_dims" ):
54+ if not isinstance (kwargs ["input_core_dims" ], dict ):
55+ raise TypeError ("input_core_dims must be an object of type 'dict'" )
56+ if not len (kwargs ["input_core_dims" ]) == 3 or any (
57+ not isinstance (value , str ) for value in kwargs ["input_core_dims" ].values ()
58+ ):
59+ raise ValueError (
60+ "input_core_dims must have three key-value pairs like: "
61+ '{"obs": "time", "simh": "time", "simp": "time"}' ,
62+ )
63+
64+ input_core_dims = kwargs ["input_core_dims" ]
65+ else :
66+ input_core_dims = {"obs" : "time" , "simh" : "time" , "simp" : "time" }
67+
5368 result : XRData = xr .apply_ufunc (
5469 __METHODS_FUNC__ [method ],
5570 obs ,
5671 simh ,
5772 # Need to spoof a fake time axis since 'time' coord on full dataset is different
5873 # than 'time' coord on training dataset.
59- simp .rename ({"time" : "t2 " }),
74+ simp .rename ({input_core_dims [ "simp" ] : "__t_simp__ " }),
6075 dask = "parallelized" ,
6176 vectorize = True ,
6277 # This will vectorize over the time dimension, so will submit each grid cell
6378 # independently
64- input_core_dims = [["time" ], ["time" ], ["t2" ]],
79+ input_core_dims = [
80+ [input_core_dims ["obs" ]],
81+ [input_core_dims ["simh" ]],
82+ ["__t_simp__" ],
83+ ],
6584 # Need to denote that the final output dataset will be labeled with the
6685 # spoofed time coordinate
67- output_core_dims = [["t2 " ]],
86+ output_core_dims = [["__t_simp__ " ]],
6887 kwargs = dict (kwargs ),
6988 )
7089
7190 # Rename to proper coordinate name.
72- result = result .rename ({"t2 " : "time" })
91+ result = result .rename ({"__t_simp__ " : input_core_dims [ "simp" ] })
7392
7493 # ufunc will put the core dimension to the end (time), so want to preserve original
7594 # order where time is commonly first.
@@ -90,6 +109,14 @@ def adjust(
90109
91110 See https://python-cmethods.readthedocs.io/en/latest/src/methods.html
92111
112+
113+ The time dimension of ``obs``, ``simh`` and ``simp`` must be named ``time``.
114+
115+ If the sizes of time dimensions of the input data sets differ, you have to
116+ pass the hidden ``input_core_dims`` parameter, see
117+ https://python-cmethods.readthedocs.io/en/latest/src/getting_started.html#advanced-usage
118+ for more information.
119+
93120 :param method: Technique to apply
94121 :type method: str
95122 :param obs: The reference/observational data set
@@ -127,14 +154,30 @@ def adjust(
127154 )
128155
129156 # Grouped correction | scaling-based technique
130- group : str = kwargs ["group" ]
157+ group : str | dict [str , str ] = kwargs ["group" ]
158+ if isinstance (group , str ):
159+ # only for same sized time dimensions
160+ obs_group = group
161+ simh_group = group
162+ simp_group = group
163+ elif isinstance (group , dict ):
164+ if any (key not in {"obs" , "simh" , "simp" } for key in group ):
165+ raise ValueError (
166+ "group must either be a string like 'time' or a dict like "
167+ '{"obs": "time.month", "simh": "t_simh.month", "simp": "time.month"}' ,
168+ )
169+ # for different sized time dimensions
170+ obs_group = group ["obs" ]
171+ simh_group = group ["simh" ]
172+ simp_group = group ["simp" ]
173+
131174 del kwargs ["group" ]
132175
133176 result : Optional [XRData ] = None
134177 for (_ , obs_gds ), (_ , simh_gds ), (_ , simp_gds ) in zip (
135- obs .groupby (group ),
136- simh .groupby (group ),
137- simp .groupby (group ),
178+ obs .groupby (obs_group ),
179+ simh .groupby (simh_group ),
180+ simp .groupby (simp_group ),
138181 ):
139182 monthly_result = apply_ufunc (
140183 method ,
0 commit comments