Skip to content

Commit 887f4f3

Browse files
authored
Merge branch 'OpenMathLib:develop' into issue5497
2 parents 19be504 + b6d5057 commit 887f4f3

File tree

14 files changed

+382
-16
lines changed

14 files changed

+382
-16
lines changed

.github/workflows/riscv64_vector.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
opts: TARGET=RISCV64_ZVL128B BINARY=64 ARCH=riscv64
2727
qemu_cpu: rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=128,elen=64
2828
- target: RISCV64_ZVL256B
29-
opts: TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64
29+
opts: TARGET=RISCV64_ZVL256B BINARY=64 ARCH=riscv64 BUILD_BFLOAT16=1 BUILD_HFLOAT16=1
3030
qemu_cpu: rv64,g=true,c=true,v=true,vext_spec=v1.0,vlen=256,elen=64
3131
- target: DYNAMIC_ARCH=1
3232
opts: TARGET=RISCV64_GENERIC BINARY=64 ARCH=riscv64 DYNAMIC_ARCH=1
@@ -40,7 +40,7 @@ jobs:
4040
run: |
4141
sudo apt-get update
4242
sudo apt-get install autoconf automake autotools-dev ninja-build make \
43-
libgomp1-riscv64-cross ccache
43+
libgomp1-riscv64-cross ccache qemu-kvm
4444
wget ${riscv_gnu_toolchain}/${riscv_gnu_toolchain_nightly_download_path}
4545
tar -xvf $(basename ${riscv_gnu_toolchain_nightly_download_path}) -C /opt
4646

cmake/cc.cmake

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ endif ()
213213

214214
if (${CORE} STREQUAL A64FX)
215215
if (NOT DYNAMIC_ARCH)
216-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
216+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
217217
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=a64fx")
218218
elseif (${GCC_VERSION} VERSION_GREATER 11.0 OR ${GCC_VERSION} VERSION_EQUAL 11.0)
219219
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.2-a+sve -mtune=a64fx")
@@ -227,7 +227,7 @@ if (${CORE} STREQUAL NEOVERSEV2)
227227
if (NOT DYNAMIC_ARCH)
228228
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
229229
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-v2")
230-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
230+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
231231
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v2")
232232
else ()
233233
if (${GCC_VERSION} VERSION_GREATER 13.0 OR ${GCC_VERSION} VERSION_EQUAL 13.0)
@@ -245,7 +245,7 @@ if (${CORE} STREQUAL NEOVERSEN2)
245245
if (NOT DYNAMIC_ARCH)
246246
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
247247
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.5-a+sve+sve2+bf16 -mtune=neoverse-n2")
248-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
248+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
249249
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v2")
250250
else ()
251251
if (${GCC_VERSION} VERSION_GREATER 11.1 OR ${GCC_VERSION} VERSION_EQUAL 11.1)
@@ -261,7 +261,7 @@ if (${CORE} STREQUAL NEOVERSEV1)
261261
if (NOT DYNAMIC_ARCH)
262262
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
263263
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8.4-a+sve+bf16 -mtune=neoverse-v1")
264-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
264+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
265265
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-v1")
266266
else ()
267267
if (${GCC_VERSION} VERSION_GREATER 10.4 OR ${GCC_VERSION} VERSION_EQUAL 10.4)
@@ -275,7 +275,7 @@ endif ()
275275

276276
if (${CORE} STREQUAL NEOVERSEN1)
277277
if (NOT DYNAMIC_ARCH)
278-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
278+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
279279
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-n1")
280280
elseif (${GCC_VERSION} VERSION_GREATER 9.4 OR ${GCC_VERSION} VERSION_EQUAL 9.4)
281281
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.2-a -mtune=neoverse-n1")
@@ -287,7 +287,7 @@ endif ()
287287

288288
if (${CORE} STREQUAL AMPEREONE)
289289
if (NOT DYNAMIC_ARCH)
290-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC")
290+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC")
291291
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=neoverse-n1")
292292
elseif (${GCC_VERSION} VERSION_GREATER 12.1)
293293
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8.6-a+crypto+crc+fp16+sha3+rng -mtune=ampereone")
@@ -301,7 +301,7 @@ if (${CORE} STREQUAL ARMV8SVE)
301301
if (NOT DYNAMIC_ARCH)
302302
if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
303303
set (CCOMMON_OPT "${CCOMMON_OPT} -Msve_intrinsics -march=armv8-a+sve")
304-
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
304+
elseif (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
305305
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=host")
306306
else ()
307307
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve")
@@ -311,7 +311,7 @@ endif ()
311311

312312
if (${CORE} STREQUAL ARMV9SME)
313313
if (NOT DYNAMIC_ARCH)
314-
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVC" AND NOT NO_SVE)
314+
if (${CMAKE_C_COMPILER_ID} STREQUAL "NVHPC" AND NOT NO_SVE)
315315
set (CCOMMON_OPT "${CCOMMON_OPT} -tp=host")
316316
else ()
317317
set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme")

common_level3.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ void ssymm_direct_alpha_betaLL(BLASLONG M, BLASLONG N,
7272
float beta,
7373
float * R, BLASLONG strideR);
7474

75+
void strmm_direct_LNUN(BLASLONG M, BLASLONG N,
76+
float alpha,
77+
float * A, BLASLONG strideA,
78+
float * B, BLASLONG strideB);
79+
void strmm_direct_LNLN(BLASLONG M, BLASLONG N,
80+
float alpha,
81+
float * A, BLASLONG strideA,
82+
float * B, BLASLONG strideB);
83+
void strmm_direct_LTUN(BLASLONG M, BLASLONG N,
84+
float alpha,
85+
float * A, BLASLONG strideA,
86+
float * B, BLASLONG strideB);
87+
void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
88+
float alpha,
89+
float * A, BLASLONG strideA,
90+
float * B, BLASLONG strideB);
91+
7592
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
7693

7794
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
260260
void (*sgemm_direct_alpha_beta) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
261261
void (*ssymm_direct_alpha_betaLU) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
262262
void (*ssymm_direct_alpha_betaLL) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float, float * , BLASLONG);
263+
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
264+
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265+
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
266+
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
263267
#endif
264268

265269

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
#define SGEMM_DIRECT_ALPHA_BETA sgemm_direct_alpha_beta
5353
#define SSYMM_DIRECT_ALPHA_BETA_LU ssymm_direct_alpha_betaLU
5454
#define SSYMM_DIRECT_ALPHA_BETA_LL ssymm_direct_alpha_betaLL
55+
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
56+
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
57+
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
58+
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
5559

5660
#define SGEMM_ONCOPY sgemm_oncopy
5761
#define SGEMM_OTCOPY sgemm_otcopy
@@ -224,6 +228,10 @@
224228
#define SGEMM_DIRECT_ALPHA_BETA gotoblas -> sgemm_direct_alpha_beta
225229
#define SSYMM_DIRECT_ALPHA_BETA_LU gotoblas -> ssymm_direct_alpha_betaLU
226230
#define SSYMM_DIRECT_ALPHA_BETA_LL gotoblas -> ssymm_direct_alpha_betaLL
231+
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
232+
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
233+
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
234+
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
227235
#endif
228236

229237
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

getarch.c

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,10 +2046,9 @@ int main(int argc, char *argv[]){
20462046
#endif
20472047

20482048

2049-
#ifdef INTEL_AMD
2050-
#ifndef FORCE
2049+
#if defined(INTEL_AMD) && !defined(FORCE)
20512050
get_sse();
2052-
#else
2051+
#elif defined(FORCE_INTEL)
20532052

20542053
sprintf(buffer, "%s", ARCHCONFIG);
20552054

@@ -2079,7 +2078,6 @@ int main(int argc, char *argv[]){
20792078
} else p ++;
20802079
}
20812080
#endif
2082-
#endif
20832081

20842082
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
20852083
printf("__BYTE_ORDER__=__ORDER_BIG_ENDIAN__\n");

interface/trsm.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,23 @@ void CNAME(enum CBLAS_ORDER order,
355355
return;
356356
}
357357

358+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
359+
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
360+
#if defined(DYNAMIC_ARCH)
361+
if (support_sme1())
362+
#endif
363+
if (args.m == 0 || args.n == 0) return;
364+
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft && m == lda && n == ldb) {
365+
if (Trans == CblasNoTrans) {
366+
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, alpha, a, lda, b, ldb);
367+
} else if (Trans == CblasTrans) {
368+
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, alpha, a, lda, b, ldb);
369+
}
370+
return;
371+
}
372+
#endif
373+
#endif
374+
358375
#endif
359376

360377
if ((args.m == 0) || (args.n == 0)) return;

kernel/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
241241
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
242242
set(USE_TRMM true)
243243
endif ()
244+
set(USE_DIRECT_STRMM false)
245+
if (ARM64)
246+
set(USE_DIRECT_STRMM true)
247+
endif()
244248
set(USE_DIRECT_SGEMM false)
245249
if (X86_64 OR ARM64)
246250
set(USE_DIRECT_SGEMM true)
@@ -283,6 +287,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
283287
endif ()
284288
endif()
285289

290+
if (USE_DIRECT_STRMM)
291+
if (ARM64)
292+
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
293+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
294+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
295+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
296+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
297+
endif ()
298+
endif ()
299+
286300
foreach (float_type SINGLE DOUBLE)
287301
string(SUBSTRING ${float_type} 0 1 float_char)
288302
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
@@ -458,6 +472,7 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
458472
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
459473
endif ()
460474

475+
461476
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
462477

463478
# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects

kernel/Makefile.L3

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
56+
USE_DIRECT_STRMM = 1
5657
endif
5758

5859
ifeq ($(ARCH), riscv64)
@@ -149,6 +150,18 @@ endif
149150
endif
150151
endif
151152

153+
ifdef USE_DIRECT_STRMM
154+
ifndef STRMMDIRECTKERNEL
155+
ifeq ($(ARCH), arm64)
156+
ifeq ($(TARGET_CORE), ARMV9SME)
157+
HAVE_SME = 1
158+
endif
159+
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
160+
endif
161+
endif
162+
endif
163+
164+
152165
ifeq ($(BUILD_BFLOAT16), 1)
153166
ifndef BGEMMKERNEL
154167
BGEMM_BETA = ../generic/gemm_beta.c
@@ -240,6 +253,14 @@ SKERNELOBJS += \
240253
endif
241254
endif
242255

256+
ifdef USE_DIRECT_STRMM
257+
ifeq ($(ARCH), arm64)
258+
SKERNELOBJS += \
259+
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
260+
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
261+
endif
262+
endif
263+
243264
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
244265
DKERNELOBJS += \
245266
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1179,6 +1200,23 @@ else
11791200
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
11801201
endif
11811202

1203+
1204+
ifdef USE_DIRECT_STRMM
1205+
ifeq ($(ARCH), arm64)
1206+
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1207+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@
1208+
1209+
$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1210+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@
1211+
1212+
$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1213+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@
1214+
1215+
$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1216+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@
1217+
endif
1218+
endif
1219+
11821220
$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
11831221
ifeq ($(OS), AIX)
11841222
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s

kernel/arm64/ssymm_direct_alpha_beta_arm64_sme1.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,11 @@ static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
189189
}
190190
}
191191
}
192-
192+
#else
193+
static void ssymm_direct_sme1_preprocessLU(uint64_t nbr, uint64_t nbc,
194+
const float *restrict a, float *restrict a_mod){}
195+
static void ssymm_direct_sme1_preprocessLL(uint64_t nbr, uint64_t nbc,
196+
const float *restrict a, float *restrict a_mod){}
193197
#endif
194198

195199
//

0 commit comments

Comments
 (0)