Skip to content

Commit 5afbff7

Browse files
committed
Move Tensors.UnsafeMathOptimizations and Tensors.SuppressWarnings to OpenCL
1 parent d173f6e commit 5afbff7

File tree

5 files changed

+34
-33
lines changed

5 files changed

+34
-33
lines changed

OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,7 @@ object OpenCL {
929929
checkBuildErrorCode(None, clBuildProgram(handle, null, options, null, NULL))
930930
}
931931

932-
def build(): Unit = build("")
932+
def build()(implicit witness: Witness.Aux[Owner]): Unit = build(witness.value.defaultProgramOptions)
933933

934934
def monadicClose = UnitContinuation.delay {
935935
OpenCL.checkErrorCode(clReleaseProgram(handle))
@@ -1113,11 +1113,34 @@ object OpenCL {
11131113
}
11141114
}
11151115

1116+
/** A plug-in of OpenCL that suppresses warnings when compiling an OpenCL kernel on non-AMD platforms. */
1117+
trait SuppressWarnings extends OpenCL {
1118+
@transient
1119+
private lazy val _defaultProgramOptions = {
1120+
if (platformCapabilities.cl_amd_compile_options) {
1121+
// AMD SDK does not support -w flag in OpenCL specification.
1122+
super.defaultProgramOptions
1123+
} else {
1124+
super.defaultProgramOptions + " -w"
1125+
}
1126+
}
1127+
1128+
override protected def defaultProgramOptions: CharSequence = _defaultProgramOptions
1129+
}
1130+
1131+
trait UnsafeMathOptimizations extends OpenCL {
1132+
private lazy val _defaultProgramOptions = super.defaultProgramOptions + " -cl-unsafe-math-optimizations"
1133+
1134+
abstract override protected def defaultProgramOptions: CharSequence = _defaultProgramOptions
1135+
}
1136+
11161137
}
11171138

11181139
trait OpenCL extends MonadicCloseable[UnitContinuation] with DefaultCloseable {
11191140
import OpenCL._
11201141

1142+
protected def defaultProgramOptions: CharSequence = ""
1143+
11211144
protected def createKernels(program: Program): Seq[Kernel] = {
11221145
val stack = stackPush()
11231146
try {

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,6 @@ object Tensors {
4141

4242
private final val MaxWorkItemDimensions = 3
4343

44-
/** A plug-in of Tensors to suppress warnings during compiling a OpenCL kernel for non-AMD platforms. */
45-
trait SuppressWarnings extends Tensors {
46-
@transient
47-
private lazy val _openclCompilerFlags = {
48-
if (platformCapabilities.cl_amd_compile_options) {
49-
// AMD SDK does not support -w flag in OpenCL specification.
50-
super.openclCompilerFlags
51-
} else {
52-
super.openclCompilerFlags + " -w"
53-
}
54-
}
55-
56-
override protected def openclCompilerFlags: String = _openclCompilerFlags
57-
}
58-
59-
trait UnsafeMathOptimizations extends Tensors {
60-
private lazy val _openclCompilerFlags = super.openclCompilerFlags + " -cl-unsafe-math-optimizations"
61-
override protected def openclCompilerFlags: String = _openclCompilerFlags
62-
}
63-
6444
trait TensorBuilder[Data] {
6545
type Element
6646
def flatten(a: Data): Seq[Element]
@@ -320,8 +300,6 @@ trait Tensors extends OpenCL {
320300

321301
protected def hashSourceCode: Fastring
322302

323-
protected def openclCompilerFlags: String = ""
324-
325303
protected object PlusPrograms extends MonoidPrograms {
326304
def append(leftHandSide: Fastring, rightHandSide: Fastring): Fastring = fast"(($leftHandSide) + ($rightHandSide))"
327305
def zero: Fastring = fast"0.0f"
@@ -368,7 +346,7 @@ trait Tensors extends OpenCL {
368346
}
369347
}
370348
""")
371-
program.build(openclCompilerFlags)
349+
program.build()
372350
program
373351
}
374352

@@ -409,7 +387,7 @@ trait Tensors extends OpenCL {
409387
}
410388
}
411389
""")
412-
program.build(openclCompilerFlags)
390+
program.build()
413391
program
414392
}
415393
}
@@ -446,7 +424,7 @@ trait Tensors extends OpenCL {
446424
buffer[i * 2 + 1] = z1;
447425
}
448426
""")
449-
program.build(openclCompilerFlags)
427+
program.build()
450428
program
451429
}
452430

@@ -460,7 +438,7 @@ trait Tensors extends OpenCL {
460438
buffer[i] = hash(i ^ seed) / 4294967296.0f;
461439
}
462440
""")
463-
program.build(openclCompilerFlags)
441+
program.build()
464442
program
465443
}
466444

@@ -1049,7 +1027,7 @@ trait Tensors extends OpenCL {
10491027
/**
10501028
* @group delayed
10511029
*/
1052-
def transpose: TransformedTensor = { permute(shape.indices.reverse.toArray)}
1030+
def transpose: TransformedTensor = { permute(shape.indices.reverse.toArray) }
10531031

10541032
/**
10551033
* @group delayed
@@ -1216,7 +1194,7 @@ trait Tensors extends OpenCL {
12161194
}
12171195

12181196
val program = createProgramWithSource(sourceCode)
1219-
program.build(openclCompilerFlags)
1197+
program.build()
12201198

12211199
val compiledKernel = new CompiledKernel {
12221200

benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ object benchmarks {
3333

3434
trait BenchmarkTensors
3535
extends StrictLogging
36-
with Tensors.UnsafeMathOptimizations
37-
with Tensors.SuppressWarnings
36+
with OpenCL.UnsafeMathOptimizations
37+
with OpenCL.SuppressWarnings
3838
with OpenCL.LogContextNotification
3939
with OpenCL.UseAllDevicesByType
4040
with OpenCL.GlobalExecutionContext

cpu/src/main/scala/com/thoughtworks/compute/cpu.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ import org.lwjgl.opencl.CL10.CL_DEVICE_TYPE_CPU
102102
*/
103103
object cpu
104104
extends StrictLogging
105-
with Tensors.UnsafeMathOptimizations
105+
with OpenCL.UnsafeMathOptimizations
106106
with OpenCL.LogContextNotification
107107
with OpenCL.GlobalExecutionContext
108108
with OpenCL.CommandQueuePool

gpu/src/main/scala/com/thoughtworks/compute/gpu.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import com.typesafe.scalalogging.StrictLogging
1414
*/
1515
object gpu
1616
extends StrictLogging
17-
with Tensors.UnsafeMathOptimizations
17+
with OpenCL.UnsafeMathOptimizations
1818
with OpenCL.LogContextNotification
1919
with OpenCL.GlobalExecutionContext
2020
with OpenCL.CommandQueuePool

0 commit comments

Comments
 (0)