@@ -1106,9 +1106,6 @@ static clblasStatus gpu_dtrsm128(
11061106 if (order != clblasColumnMajor)
11071107 return clblasNotImplemented;
11081108
1109- // for now
1110- if (side == clblasRight)
1111- return clblasNotImplemented;
11121109
11131110 int inner_block_size = 16 ; // inner blocking size, <=32
11141111 int outer_block_size = 128 ;// outer blocking size, >BLOCK_SIZE
@@ -1285,8 +1282,147 @@ static clblasStatus gpu_dtrsm128(
12851282 }
12861283 else
12871284 {
1288- clReleaseMemObject (X);
1289- return clblasNotImplemented;
1285+ //
1286+ // Helper for C = alpha * B * A + beta * C
1287+ //
1288+ // In the calls below
1289+ // - the 2nd matrix shall be either A or InvA transposed according to transA
1290+ // - the 1st and 3rd matrices are either B and X
1291+ //
1292+ #define DGEMM_RIGHT (m,n,k, alpha, B, A, beta, C ) \
1293+ do { \
1294+ err = clblasDgemm (clblasColumnMajor, clblasNoTrans, transA , m, n, k, alpha, B, A, beta, C , 1 , commandQueues, 0 , NULL , events ) ; \
1295+ CL_CHECK (err); \
1296+ } while (0 )
1297+
1298+
1299+ // side=R
1300+ /* invert the diagonals
1301+ * Allocate device memory for the inverted diagonal blocks, size=n*BLOCK_SIZE
1302+ */
1303+
1304+ /* invert the diagonals
1305+ * Allocate device memory for the inverted diagonal blocks, size=m*nb
1306+ */
1307+ size_t ldInvA = outer_block_size;
1308+ size_t offInvA = 0 ; // must be 0: needed by the _(X,i,j) macro
1309+ size_t size_InvA = ldInvA * BLOCKS (N, outer_block_size) * outer_block_size *sizeof (double );
1310+ InvA = clCreateBuffer (context, CL_MEM_READ_WRITE, size_InvA, NULL , &err);
1311+ CL_CHECK (err);
1312+ err = clearBuffer (commandQueues[0 ], InvA, size_InvA);
1313+ CL_CHECK (err);
1314+
1315+ err = diag_dtrtri128 (commandQueues[0 ], N, uplo, diag, A, offA, InvA, ldA, inner_block_size, outer_block_size, events);
1316+ CL_CHECK (err);
1317+
1318+
1319+ if (transA == clblasNoTrans)
1320+ {
1321+ /* the non-transpose case */
1322+ if (uplo == clblasLower)
1323+ {
1324+ /* the lower case */
1325+ /* handle the first block seperately with alpha */
1326+
1327+ int nn = (N % outer_block_size == 0 ) ? outer_block_size : (N % outer_block_size);
1328+ i = N - nn;
1329+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1330+
1331+ if (i - outer_block_size >= 0 )
1332+ {
1333+
1334+ DGEMM_RIGHT (M, i, nn, neg_one, _ (X, 0 , i), _ (A, i, 0 ), alpha, _ (B, 0 , 0 ));
1335+
1336+ /* the rest blocks */
1337+ for (i = N - nn - outer_block_size; i >= 0 ; i -= outer_block_size) {
1338+ DGEMM_RIGHT (M, outer_block_size, outer_block_size, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1339+
1340+ if (i - outer_block_size < 0 )
1341+ break ;
1342+
1343+ DGEMM_RIGHT (M, i, outer_block_size, neg_one, _ (X, 0 , i), _ (A, i, 0 ), one, _ (B, 0 , 0 ));
1344+ }
1345+ }
1346+ }
1347+ else
1348+ {
1349+ /* the upper case */
1350+ /* handle the first block seperately with alpha */
1351+ int nn = min (outer_block_size, (int )N);
1352+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , 0 ), _ (InvA, 0 , 0 ), zero, _ (X, 0 , 0 ));
1353+
1354+ if (outer_block_size < N)
1355+ {
1356+
1357+ DGEMM_RIGHT (M, N - outer_block_size, outer_block_size, neg_one, _ (X, 0 , 0 ), _ (A, 0 , outer_block_size), alpha, _ (B, 0 , outer_block_size));
1358+
1359+ /* the rest blocks */
1360+ for (i = outer_block_size; i < N; i += outer_block_size) {
1361+ nn = min (outer_block_size, (int )N - i);
1362+ DGEMM_RIGHT (M, nn, nn, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1363+
1364+ if (i + outer_block_size >= N)
1365+ break ;
1366+
1367+ DGEMM_RIGHT (M, N - i - outer_block_size, outer_block_size, neg_one, _ (X, 0 , i), _ (A, i, i + outer_block_size), one, _ (B, 0 , i + outer_block_size));
1368+ }
1369+ }
1370+ }
1371+ }
1372+ else
1373+ {
1374+
1375+ /* the transpose case */
1376+ if (uplo == clblasLower)
1377+ {
1378+ /* the lower case */
1379+ /* handle the first block seperately with alpha */
1380+
1381+ int nn = min (outer_block_size, (int )N);
1382+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , 0 ), _ (InvA, 0 , 0 ), zero, _ (X, 0 , 0 ));
1383+
1384+ if (outer_block_size < N)
1385+ {
1386+
1387+ DGEMM_RIGHT (M, N - outer_block_size, outer_block_size, neg_one, _ (X, 0 , 0 ), _ (A, outer_block_size, 0 ), alpha, _ (B, 0 , outer_block_size));
1388+
1389+ /* the rest blocks */
1390+ for (i = outer_block_size; i < N; i += outer_block_size) {
1391+ nn = min (outer_block_size, (int )N - i);
1392+ DGEMM_RIGHT (M, nn, nn, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1393+
1394+ if (i + outer_block_size >= N)
1395+ break ;
1396+
1397+ DGEMM_RIGHT (M, N - i - outer_block_size, outer_block_size, neg_one, _ (X, 0 , i), _ (A, outer_block_size + i, i), one, _ (B, 0 , i + outer_block_size));
1398+ }
1399+ }
1400+ }
1401+ else
1402+ {
1403+ /* the upper case */
1404+ /* handle the first block seperately with alpha */
1405+ int nn = (N % outer_block_size == 0 ) ? outer_block_size : (N % outer_block_size);
1406+ i = N - nn;
1407+ DGEMM_RIGHT (M, nn, nn, alpha, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1408+
1409+ if (i - outer_block_size >= 0 )
1410+ {
1411+
1412+ DGEMM_RIGHT (M, i, nn, neg_one, _ (X, 0 , i), _ (A, 0 , i), alpha, _ (B, 0 , 0 ));
1413+
1414+ /* the rest blocks */
1415+ for (i = N - nn - outer_block_size; i >= 0 ; i -= outer_block_size) {
1416+ DGEMM_RIGHT (M, outer_block_size, outer_block_size, one, _ (B, 0 , i), _ (InvA, 0 , i), zero, _ (X, 0 , i));
1417+
1418+ if (i - outer_block_size < 0 )
1419+ break ;
1420+
1421+ DGEMM_RIGHT (M, i, outer_block_size, neg_one, _ (X, 0 , i), _ (A, 0 , i), one, _ (B, 0 , 0 ));
1422+ }
1423+ }
1424+ }
1425+ }
12901426 }
12911427
12921428 // Copy X(m,n) to B(m,n)
0 commit comments