@@ -791,10 +791,10 @@ def select(condition, input, other):
791791 return pyboost .select_op (condition , input , other )
792792 return legacy .select (condition , input , other )
793793
794- def mean (input , axis , keepdims , dtype ):
794+ def mean (input , dim , keepdim , dtype ):
795795 if use_pyboost ():
796- return pyboost .mean_ext_op (input , axis , keepdims , dtype )
797- return legacy .reduce_mean (input , axis , keepdims )
796+ return pyboost .mean_ext_op (input , dim , keepdim , dtype )
797+ return legacy .reduce_mean (input , dim , keepdim )
798798
799799def index (input , index ):
800800 if use_pyboost ():
@@ -1552,9 +1552,29 @@ def one_hot(tensor, num_classes):
15521552 return legacy .one_hot (tensor , num_classes , on_value , off_value , - 1 )
15531553
15541554def var (input , dim = None , correction = 1 , keepdim = False ):
1555- if use_pyboost ():
1555+ if use_pyboost () and not ON_ORANGE_PI :
15561556 return pyboost .var_op (input , dim , correction , keepdim )
1557- return legacy .var (input , dim , correction , keepdim )
1557+ if dim is None :
1558+ input_mean = mean (input , (), False , None )
1559+ else :
1560+ input_mean = mean (input , dim = dim , keepdim = True , dtype = None )
1561+
1562+ # 计算与均值的平方差
1563+ squared_diff = pow (sub (input , input_mean , 1 ), 2 )
1564+ # 计算方差
1565+ if dim is None :
1566+ variance = mean (squared_diff , (), False , None )
1567+ n = input .numel () # 总元素个数
1568+ else :
1569+ variance = mean (squared_diff , dim = dim , keepdim = keepdim , dtype = None )
1570+ n = input .size (dim ) # 指定维度的元素个数
1571+
1572+ # 无偏估计校正
1573+ if correction and n > 1 :
1574+ variance = mul (variance , (n / (n - 1 )))
1575+
1576+ return variance
1577+
15581578
15591579def linspace (start , end , steps , dtype = None ):
15601580 if use_pyboost () and not ON_ORANGE_PI :
0 commit comments