Skip to content

Commit 84397cb

Browse files
committed
Bug fixes for linalg module
1 parent e939d9d commit 84397cb

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

geomdl/linalg.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import os
1111
import math
12-
from operator import add, sub, mul, truediv, pow
12+
from operator import add, sub, mul, truediv
1313
from copy import deepcopy
1414
from functools import reduce
1515
try:
@@ -499,7 +499,7 @@ def _doolittle(ma, dt):
499499
return ml, mu
500500

501501
# Data type, e.g. float, Decimal, etc.
502-
dtype = kwargs.get('dtype', matrix_a[0][0])
502+
dtype = kwargs.get('dtype', type(matrix_a[0][0]))
503503

504504
# Check if the 2-dimensional input matrix is a square matrix
505505
q = len(matrix_a)
@@ -525,7 +525,7 @@ def forward_substitution(matrix_l, matrix_b, **kwargs):
525525
:return: y, column matrix
526526
:rtype: list
527527
"""
528-
dtype = kwargs.get('dtype', matrix_l[0][0])
528+
dtype = kwargs.get('dtype', type(matrix_l[0][0]))
529529
q = len(matrix_b)
530530
matrix_y = [dtype(0.0) for _ in range(q)]
531531
matrix_y[0] = truediv(matrix_b[0], matrix_l[0][0])
@@ -548,7 +548,7 @@ def backward_substitution(matrix_u, matrix_y, **kwargs):
548548
:return: x, column matrix
549549
:rtype: list
550550
"""
551-
dtype = kwargs.get('dtype', matrix_u[0][0])
551+
dtype = kwargs.get('dtype', type(matrix_u[0][0]))
552552
q = len(matrix_y)
553553
matrix_x = [dtype(0.0) for _ in range(q)]
554554
matrix_x[q - 1] = truediv(matrix_y[q - 1], matrix_u[q - 1][q - 1])
@@ -574,7 +574,7 @@ def lu_solve(matrix_a, b, **kwargs):
574574
:rtype: list
575575
"""
576576
# Data type, e.g. float, Decimal, etc.
577-
dtype = kwargs.get('dtype', matrix_a[0][0])
577+
dtype = kwargs.get('dtype', type(matrix_a[0][0]))
578578
# Variable initialization
579579
dim = len(b[0])
580580
num_x = len(b)
@@ -611,7 +611,7 @@ def lu_factor(matrix_a, b, **kwargs):
611611
:rtype: list
612612
"""
613613
# Data type, e.g. float, Decimal, etc.
614-
dtype = kwargs.get('dtype', matrix_a[0][0])
614+
dtype = kwargs.get('dtype', type(matrix_a[0][0]))
615615
# Variable initialization
616616
dim = len(b[0])
617617
num_x = len(b)

0 commit comments

Comments
 (0)