@@ -265,27 +265,6 @@ def test_decimal_with_when(self):
265265 Decimal ('60' ), Decimal ('55000' ), Decimal ('0' ),
266266 'end' ))
267267
268- # begin
269- assert_equal (npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
270- Decimal ('24' ), Decimal ('2000' ),
271- Decimal ('0' ), Decimal ('1' )).flat [0 ],
272- npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
273- Decimal ('24' ), Decimal ('2000' ),
274- Decimal ('0' ), 'begin' ).flat [0 ])
275- # end
276- assert_equal (npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
277- Decimal ('24' ), Decimal ('2000' ),
278- Decimal ('0' )).flat [0 ],
279- npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
280- Decimal ('24' ), Decimal ('2000' ),
281- Decimal ('0' ), 'end' ).flat [0 ])
282- assert_equal (npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
283- Decimal ('24' ), Decimal ('2000' ),
284- Decimal ('0' ), Decimal ('0' )).flat [0 ],
285- npf .ipmt (Decimal ('0.1' ) / Decimal ('12' ), Decimal ('1' ),
286- Decimal ('24' ), Decimal ('2000' ),
287- Decimal ('0' ), 'end' ).flat [0 ])
288-
289268 def test_broadcast (self ):
290269 assert_almost_equal (npf .nper (0.075 , - 2000 , 0 , 100000. , [0 , 1 ]),
291270 [21.5449442 , 20.76156441 ], 4 )
@@ -374,6 +353,33 @@ def test_when_is_end(self, when):
374353 result = npf .ipmt (0.1 / 12 , 1 , 24 , 2000 , 0 , when )
375354 assert_allclose (result , - 16.666667 , rtol = 1e-6 )
376355
356+
357+ @pytest .mark .parametrize ('when' , [Decimal ('1' ), 'begin' ])
358+ def test_when_is_begin_decimal (self , when ):
359+ result = npf .ipmt (
360+ Decimal ('0.1' ) / Decimal ('12' ),
361+ Decimal ('1' ),
362+ Decimal ('24' ),
363+ Decimal ('2000' ),
364+ Decimal ('0' ),
365+ when ,
366+ )
367+ assert result == 0
368+
369+ @pytest .mark .parametrize ('when' , [None , Decimal ('0' ), 'end' ])
370+ def test_when_is_end_decimal (self , when ):
371+ # Computed using Google Sheet's IPMT
372+ desired = Decimal ('-16.666667' )
373+ args = (
374+ Decimal ('0.1' ) / Decimal ('12' ),
375+ Decimal ('1' ),
376+ Decimal ('24' ),
377+ Decimal ('2000' ),
378+ Decimal ('0' )
379+ )
380+ result = npf .ipmt (* args ) if when is None else npf .ipmt (* args , when )
381+ assert_almost_equal (result , desired , decimal = 5 )
382+
377383 @pytest .mark .parametrize ('per, desired' , [
378384 (0 , numpy .nan ),
379385 (1 , 0 ),
@@ -418,6 +424,14 @@ def test_decimal_broadcasting(self):
418424 )
419425 assert_almost_equal (result , desired , decimal = 4 )
420426
427+ def test_0d_inputs (self ):
428+ args = (0.1 / 12 , 1 , 24 , 2000 )
429+ # Scalar inputs should return a scalar.
430+ assert numpy .isscalar (npf .ipmt (* args ))
431+ args = (numpy .array (args [0 ]),) + args [1 :]
432+ # 0d array inputs should return a scalar.
433+ assert numpy .isscalar (npf .ipmt (* args ))
434+
421435
422436class TestFv :
423437 def test_float (self ):
0 commit comments