22#define _NBL_BUILTIN_HLSL_EMULATED_MATRIX_T_HLSL_INCLUDED_
33
44#include <nbl/builtin/hlsl/portable/float64_t.hlsl>
5+ #include <nbl/builtin/hlsl/emulated/vector_t.hlsl>
56#include <nbl/builtin/hlsl/matrix_utils/matrix_traits.hlsl>
67
78namespace nbl
@@ -63,9 +64,14 @@ struct matrix_traits<emulated_matrix<T, ROW_COUNT, COLUMN_COUNT> > \
6364};
6465
6566DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 2 )
67+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 3 )
68+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (2 , 4 )
69+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 2 )
6670DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 3 )
67- DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 4 )
6871DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (3 , 4 )
72+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 2 )
73+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 3 )
74+ DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION (4 , 4 )
6975
7076#undef DEFINE_MATRIX_TRAITS_TEMPLATE_SPECIALIZATION
7177
@@ -91,18 +97,18 @@ struct mul_helper<emulated_matrix<ComponentT, N, M>, emulated_matrix<ComponentT,
9197
9298 static inline return_t __call (LhsT lhs, RhsT rhs)
9399 {
94- typename matrix_traits<RhsT>::transposed_type rhsTransposed = rhs.getTransposed ();
95- const uint32_t outputRowCount = matrix_traits<return_t>::RowCount;
96- const uint32_t outputColumnCount = matrix_traits<return_t>::ColumnCount;
97100 using OutputVecType = typename matrix_traits<return_t>::row_type;
101+ const uint32_t outputRowCount = vector_traits<OutputVecType>::Dimension;
98102
99- nbl::hlsl::array_set<OutputVecType , typename vector_traits<OutputVecType >::scalar_type> setter ;
103+ nbl::hlsl::array_get<typename matrix_traits<LhsT>::row_type , typename vector_traits<typename matrix_traits<LhsT >::row_type>:: scalar_type> getter ;
100104
101105 return_t output;
102- for (int r = 0 ; r < outputRowCount; ++r)
106+ const uint32_t RHSRowCount = matrix_traits<RhsT>::RowCount;
107+ for (uint32_t rO = 0 ; rO < outputRowCount; ++rO)
103108 {
104- for (int c = 0 ; c < outputColumnCount; ++c)
105- setter (output.rows[r], c, dot<OutputVecType>(lhs.rows[r], rhsTransposed.rows[c]));
109+ output.rows[rO] = rhs.rows[0 ] * getter (lhs.rows[rO], 0 );
110+ for (uint32_t rI = 1 ; rI < RHSRowCount; ++rI) // its also the LHS column count
111+ output.rows[rO] = output.rows[rO] + rhs.rows[rI] * getter (lhs.rows[rO], rI);
106112 }
107113
108114 return output;
@@ -132,4 +138,4 @@ struct mul_helper<emulated_matrix<ComponentT, RowCount, ColumnCount>, emulated_v
132138
133139}
134140}
135- #endif
141+ #endif
0 commit comments