#ifndef GEMM_COMMON_C
#define GEMM_COMMON_C
#include "common.h"

#include <altivec.h>
#include <inttypes.h>
#include <string.h>	/* memcpy/memset used by the pre-POWER9 fallbacks below */

#define NBMAX 4096

#define FORCEINLINE inline __attribute__((always_inline))

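/* POWER10 can move two VSX registers at once through a __vector_pair.
   Older compilers expose the pair assemble/disassemble builtins only under
   their original MMA-prefixed names, so map the VSX names onto those when
   the compiler does not already provide them. */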
#ifdef _ARCH_PWR10
#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair
#endif
#if !__has_builtin(__builtin_vsx_disassemble_pair)
#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair
#endif
#endif

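/* The assemble builtin's vector operand order is endian-sensitive, so wrap
   it once here and let callers always pass the same operand order. */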
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0)
#else
#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1)
#endif

#define USE_VECTOR_PAIRS
#endif

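/* 16-byte vector views: vec_bf16 holds 8 IFLOAT (bfloat16) input elements,
   vec_f32 holds 4 FLOAT elements, and vec_uc8 is a raw byte view used for
   casts between them. */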
typedef __vector IFLOAT vec_bf16;
typedef __vector FLOAT vec_f32;
typedef __vector unsigned char vec_uc8;

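/* Unaligned 16-byte load from src. */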
FORCEINLINE vec_uc8 vec_load_vec(void *src)
{
  return vec_xl(0, (unsigned char *)(src));
}

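/* Load two consecutive vectors from src into dst[0..1], using a single
   paired (lxvp) load when vector pairs are available. */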
FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
#ifdef __clang__
  vy0p = __builtin_vsx_lxvp(0L, (const __vector_pair *)(src));
#else
  vy0p = *(__vector_pair *)(src);
#endif
  __builtin_vsx_disassemble_pair((void *)(dst), &vy0p);
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

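/* Store two consecutive vectors from src[0..1] to dst, using a single
   paired (stxvp) store when vector pairs are available. */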
FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src)
{
#ifdef USE_VECTOR_PAIRS
  __vector_pair vy0p;
  __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]);
#ifdef __clang__
  __builtin_vsx_stxvp(vy0p, 0L, (__vector_pair *)(dst));
#else
  *(__vector_pair *)(dst) = vy0p;
#endif
#else
  dst[0] = src[0];
  dst[1] = src[1];
#endif
}

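/* Load n IFLOAT (bfloat16) elements, zero-filling the rest of the vector.
   POWER9 has a length-controlled load (lxvl); otherwise assemble the partial
   vector on the stack from 8-, 4- and 2-byte pieces selected by the bits
   of n. */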
FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n)
{
  IFLOAT *src2 = (IFLOAT *)(src);
#ifdef _ARCH_PWR9
  return vec_xl_len(src2, n * sizeof(IFLOAT));
#else
  __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)];
  memset(data, 0, sizeof(vec_bf16));
  if (n & 4) {
    memcpy(data, src2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(data + n4, src2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    data[n6] = src2[n6];
  }
  return (vec_bf16)vec_load_vec(data);
#endif
}

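/* Load n FLOAT elements (n <= 4), zero-filling the tail. vec_loadN counts
   IFLOAT-sized units, hence the sizeof scale factor. */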
FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    return (vec_f32)vec_load_vec(src);
  }
#endif
  return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

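/* Load one full vector plus n trailing FLOAT elements (4 + n in total). */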
FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n)
{
  data[0] = src[0];
  data[1] = vec_loadN_f32(&src[1], n);
}

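/* Store the first n IFLOAT (bfloat16) elements of data. POWER9 has a
   length-controlled store (stxvl); otherwise spill the vector to the stack
   and copy out 8-, 4- and 2-byte pieces selected by the bits of n. */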
FORCEINLINE void vec_storeN(vec_bf16 data, void *dst, BLASLONG n)
{
  IFLOAT *dst2 = (IFLOAT *)(dst);
#ifdef _ARCH_PWR9
  vec_xst_len(data, dst2, n * sizeof(IFLOAT));
#else
  if (n & 8) {
    vec_xst(data, 0, dst2);
    return;
  }
  __attribute__((aligned(16))) IFLOAT data2[sizeof(vec_f32) / sizeof(IFLOAT)];
  vec_xst(data, 0, data2);
  if (n & 4) {
    memcpy(dst2, data2, sizeof(uint64_t));
  }
  if (n & 2) {
    BLASLONG n4 = n & 4;
    memcpy(dst2 + n4, data2 + n4, sizeof(uint32_t));
  }
  if (n & 1) {
    BLASLONG n6 = n & 6;
    dst2[n6] = data2[n6];
  }
#endif
}

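/* Store the first n FLOAT elements (n <= 4) of data. */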
FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n)
{
#ifndef _ARCH_PWR9
  if (n & 4) {
    vec_xst(data, 0, (FLOAT *)dst);
    return;
  }
#endif
  vec_storeN((vec_bf16)data, dst, n * (sizeof(FLOAT) / sizeof(IFLOAT)));
}

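/* Store one full vector plus n trailing FLOAT elements (4 + n in total). */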
FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n)
{
  dst[0] = data[0];
  vec_storeN_f32(data[1], &dst[1], n);
}
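
/* Example (hypothetical caller, not part of this header): copy a row of
   7 floats, the first four through the full-vector slot and the 3-element
   tail through the masked path:

     vec_f32 v[2];
     vec_loadN2_f32(v, (vec_f32 *)src_row, 3);
     vec_storeN2_f32(v, (vec_f32 *)dst_row, 3);
*/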
#endif /* GEMM_COMMON_C */