Skip to content

Commit a5413d7

Browse files
committed
init branch
1 parent 169d4e1 commit a5413d7

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

examples/webgpu_compute_reduce.html

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ <h3 id="panel-title" style="flex: 0 0 auto;">Subgroup Reduction Explanation</h3>
190190
<script type="module">
191191

192192
import * as THREE from 'three/webgpu';
193-
import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, log2 } from 'three/tsl';
193+
import { instancedArray, Loop, If, vec3, dot, clamp, storage, uvec4, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, select, countTrailingZeros } from 'three/tsl';
194194

195195
import WebGPU from 'three/addons/capabilities/WebGPU.js';
196196

@@ -831,12 +831,12 @@ <h3 id="panel-title" style="flex: 0 0 auto;">Subgroup Reduction Explanation</h3>
831831

832832
// Multiple approaches here
833833
// log2(subgroupSize) -> TSL log2 function
834-
// countTrailingZeros/findLSB(subgroupSize) -> Currently unsupported function in TSL that counts trailing zeros in number bit representation
834+
// countTrailingZeros/findLSB(subgroupSize) -> TSL function that counts trailing zeros in number bit representation
835835
// Can technically petition GPU for subgroupSize in shader and calculate logs on CPU at cost of shader being generalizable across devices
836836
// May also break if subgroupSize changes when device is lost or if program is rerun on lower power device
837-
const subgroupSizeLog = uint( log2( float( subgroupSize ) ) ).toVar( 'subgroupSizeLog' );
837+
const subgroupSizeLog = countTrailingZeros( subgroupSize ).toVar( 'subgroupSizeLog' );
838838
const spineSize = uint( workgroupSize ).shiftRight( subgroupSizeLog );
839-
const spineSizeLog = uint( log2( float( spineSize ) ) ).toVar( 'spineSizeLog' );
839+
const spineSizeLog = countTrailingZeros( spineSize ).toVar( 'spineSizeLog' );
840840

841841

842842
// Align size to powers of subgroupSize

src/Three.TSL.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ export const batch = TSL.batch;
8080
export const bentNormalView = TSL.bentNormalView;
8181
export const billboarding = TSL.billboarding;
8282
export const bitAnd = TSL.bitAnd;
83+
export const bitCount = TSL.bitCount;
8384
export const bitNot = TSL.bitNot;
8485
export const bitOr = TSL.bitOr;
8586
export const bitXor = TSL.bitXor;
@@ -136,6 +137,9 @@ export const context = TSL.context;
136137
export const convert = TSL.convert;
137138
export const convertColorSpace = TSL.convertColorSpace;
138139
export const convertToTexture = TSL.convertToTexture;
140+
export const countLeadingZeros = TSL.countLeadingZeros;
141+
export const countOneBits = TSL.countOneBits;
142+
export const countTrailingZeros = TSL.countTrailingZeros;
139143
export const cos = TSL.cos;
140144
export const cross = TSL.cross;
141145
export const cubeTexture = TSL.cubeTexture;
@@ -180,6 +184,8 @@ export const expression = TSL.expression;
180184
export const faceDirection = TSL.faceDirection;
181185
export const faceForward = TSL.faceForward;
182186
export const faceforward = TSL.faceforward;
187+
export const findLSB = TSL.findLSB;
188+
export const findMSB = TSL.findMSB;
183189
export const float = TSL.float;
184190
export const floatBitsToInt = TSL.floatBitsToInt;
185191
export const floatBitsToUint = TSL.floatBitsToUint;

src/nodes/math/MathNode.js

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,9 @@ MathNode.FWIDTH = 'fwidth';
364364
MathNode.TRANSPOSE = 'transpose';
365365
MathNode.DETERMINANT = 'determinant';
366366
MathNode.INVERSE = 'inverse';
367+
MathNode.COUNT_TRAILING_ZEROS = 'countTrailingZeros';
368+
MathNode.COUNT_LEADING_ZEROS = 'countLeadingZeros';
369+
MathNode.COUNT_ONE_BITS = 'countOneBits';
367370

368371
// 2 inputs
369372

@@ -1099,10 +1102,50 @@ export const atan2 = ( y, x ) => { // @deprecated, r172
10991102

11001103
};
11011104

1105+
1106+
/**
1107+
* Finds the number of consecutive 0 bits from the least significant bit of the input value,
1108+
* which is also the index of the least significant bit of the input value.
1109+
*
1110+
* Can only be used with {@link WebGPURenderer} and a WebGPU backend.
1111+
*
1112+
* @tsl
1113+
* @function
1114+
* @param {Node | number} x - The input value.
1115+
* @returns {Node}
1116+
*/
1117+
export const countTrailingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_TRAILING_ZEROS ).setParameterLength( 1 );
1118+
1119+
/**
1120+
* Finds the number of consecutive 0 bits starting from the most significant bit of the input value.
1121+
*
1122+
* Can only be used with {@link WebGPURenderer} and a WebGPU backend.
1123+
*
1124+
* @tsl
1125+
* @function
1126+
* @param {Node | number} x - The input value.
1127+
* @returns {Node}
1128+
*/
1129+
export const countLeadingZeros = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_LEADING_ZEROS ).setParameterLength( 1 );
1130+
1131+
/**
1132+
* Finds the number of '1' bits set in the input value
1133+
*
1134+
* Can only be used with {@link WebGPURenderer} and a WebGPU backend.
1135+
*
1136+
* @tsl
1137+
* @function
1138+
* @returns {Node}
1139+
*/
1140+
export const countOneBits = /*@__PURE__*/ nodeProxyIntent( MathNode, MathNode.COUNT_ONE_BITS ).setParameterLength( 1 );
1141+
11021142
// GLSL alias function
11031143

11041144
export const faceforward = faceForward;
11051145
export const inversesqrt = inverseSqrt;
1146+
export const findLSB = countTrailingZeros;
1147+
export const findMSB = countLeadingZeros;
1148+
export const bitCount = countOneBits;
11061149

11071150
// Method chaining
11081151

@@ -1165,3 +1208,6 @@ addMethodChaining( 'transpose', transpose );
11651208
addMethodChaining( 'determinant', determinant );
11661209
addMethodChaining( 'inverse', inverse );
11671210
addMethodChaining( 'rand', rand );
1211+
addMethodChaining( 'countTrailingZeros', countTrailingZeros );
1212+
addMethodChaining( 'countLeadingZeros', countLeadingZeros );
1213+
addMethodChaining( 'countOneBits', countOneBits );

0 commit comments

Comments
 (0)