Commit f1b684f

Support quantization of loadConstantND (#1933)

1 parent 63eb725 commit f1b684f

2 files changed: +64, -0 lines

coremltools/models/neural_network/quantization_utils.py

Lines changed: 8 additions & 0 deletions

@@ -78,6 +78,7 @@ def __init__(self):
             "scale",
             "bias",
             "loadConstant",
+            "loadConstantND",
             "simpleRecurrent",
             "gru",
             "uniDirectionalLSTM",
@@ -903,6 +904,13 @@ def _quantize_nn_spec(nn_spec, nbits, qm, **kwargs):
                 layer.loadConstant.data, nbits, qm, shape=(nw,), **kwargs
             )
 
+        # LoadConstantND layer
+        elif layer_type == "loadConstantND":
+            nw = _np.prod(layer.loadConstantND.shape)
+            _quantize_wp_field(
+                layer.loadConstantND.data, nbits, qm, shape=(nw,), **kwargs
+            )
+
         # Simple Recurrent
         elif layer_type == "simpleRecurrent":
            i_size = layer.simpleRecurrent.inputVectorSize
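
The new branch mirrors the existing loadConstant path: the N-dimensional constant is flattened to a 1-D weight vector of np.prod(shape) elements before _quantize_wp_field quantizes it. The payload sizes this produces follow directly from the bit width; the sketch below reproduces that arithmetic in plain NumPy (quantized_size_bytes is an illustrative helper, not part of coremltools) and matches the byte counts asserted in the test added below.

import math

import numpy as np

def quantized_size_bytes(shape, nbits):
    # FP16 stores 2 bytes per element; for n-bit linear quantization the
    # integer codes are packed into ceil(nbits * n / 8) raw bytes.
    n = int(np.prod(shape))  # flattened element count, as in the diff above
    if nbits == 16:
        return 2 * n
    return math.ceil(nbits * n / 8)

# For the (10, 10) constant used in the test:
assert quantized_size_bytes((10, 10), 16) == 2 * 100  # float16Value
assert quantized_size_bytes((10, 10), 8) == 100       # rawValue
assert quantized_size_bytes((10, 10), 5) == 63        # 63 = ceil(5 * 100 / 8)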

coremltools/test/neural_network/test_quantization.py

Lines changed: 56 additions & 0 deletions

@@ -562,6 +562,62 @@ def test_embeddingND_quantize(compute_units):
     def test_embeddingND_quantize_CPU_and_NE(self):
         self.test_embeddingND_quantize(ComputeUnit.CPU_AND_NE)
 
+    @staticmethod
+    @pytest.mark.parametrize(
+        "compute_units", [ComputeUnit.ALL, ComputeUnit.CPU_AND_GPU, ComputeUnit.CPU_ONLY]
+    )
+    def test_loadConstantND_quantize(compute_units):
+        input_features = [("data", datatypes.Array(10, 1))]
+        output_features = [("output", None)]
+        builder = neural_network.NeuralNetworkBuilder(
+            input_features, output_features, disable_rank5_shape_mapping=True
+        )
+
+        output_shape = [10, 10]
+        ori_value = np.random.randint(0, 200, output_shape).astype(np.float32)
+        builder.add_load_constant_nd(
+            name="load_constant_nd",
+            output_name="constant",
+            constant_value=ori_value,
+            shape=output_shape)
+
+        builder.add_broadcast_to_dynamic(
+            name="broadcast_to_dynamic", input_names=["constant", "data"], output_name="output"
+        )
+
+        spec = builder.spec
+        model_fp32 = coremltools.models.MLModel(spec, compute_units=compute_units)
+        assert len(spec.neuralNetwork.layers[0].loadConstantND.data.floatValue) == 100
+
+        # quantize to FP16
+        model_fp16 = quantization_utils.quantize_weights(model_fp32, nbits=16)
+        assert model_fp16.compute_unit == compute_units
+        spec_fp16 = model_fp16.get_spec()
+        assert len(spec_fp16.neuralNetwork.layers[0].loadConstantND.data.floatValue) == 0
+        assert len(spec_fp16.neuralNetwork.layers[0].loadConstantND.data.float16Value) == 2 * 100
+
+        # quantize to uint8
+        model_uint8 = quantization_utils.quantize_weights(model_fp32, nbits=8)
+        assert model_uint8.compute_unit == compute_units
+        spec_uint8 = model_uint8.get_spec()
+        assert len(spec_uint8.neuralNetwork.layers[0].loadConstantND.data.floatValue) == 0
+        assert len(spec_uint8.neuralNetwork.layers[0].loadConstantND.data.float16Value) == 0
+        assert len(spec_uint8.neuralNetwork.layers[0].loadConstantND.data.rawValue) == 100
+
+        # quantize to uint5
+        model_uint5 = quantization_utils.quantize_weights(model_fp32, nbits=5)
+        assert model_uint5.compute_unit == compute_units
+        spec_uint5 = model_uint5.get_spec()
+        assert len(spec_uint5.neuralNetwork.layers[0].loadConstantND.data.floatValue) == 0
+        assert len(spec_uint5.neuralNetwork.layers[0].loadConstantND.data.float16Value) == 0
+        assert len(spec_uint5.neuralNetwork.layers[0].loadConstantND.data.rawValue) == 63  # 63 = ceil(5*100/8)
+
+    @unittest.skipIf(coremltools.utils._macos_version() < (13, 0),
+                     'ComputeUnit.CPU_AND_NE is only available on macOS >= 13.0'
+    )
+    def test_loadConstantND_quantize_CPU_and_NE(self):
+        self.test_loadConstantND_quantize(ComputeUnit.CPU_AND_NE)
+
 
 class TestKMeansLookup:
     @pytest.mark.parametrize("weightShape, dtype",
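
The test pins down payload sizes but not numerical fidelity. For intuition about what n-bit linear quantization costs in accuracy on a constant like this, here is a self-contained NumPy sketch of an idealized quantize/dequantize round trip (assuming plain linear quantization; linear_quantize_roundtrip is illustrative, not the coremltools implementation):

import numpy as np

# Same kind of constant as in the test: 100 random values in [0, 200).
ori_value = np.random.randint(0, 200, (10, 10)).astype(np.float32)

def linear_quantize_roundtrip(w, nbits):
    # Map w linearly onto 2**nbits integer levels, then map back,
    # returning the dequantized array.
    lo, hi = float(w.min()), float(w.max())
    levels = 2 ** nbits - 1
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = np.round((w - lo) / scale)  # integer codes in [0, levels]
    return codes * scale + lo

for nbits in (8, 5):
    step = (ori_value.max() - ori_value.min()) / (2 ** nbits - 1)
    err = np.abs(linear_quantize_roundtrip(ori_value, nbits) - ori_value).max()
    print(f"nbits={nbits}: quantization step {step:.3f}, max abs error {err:.3f}")

The worst-case reconstruction error is about half a quantization step, which is why 8 bits are usually safe for a constant with modest dynamic range while 5 bits can be noticeably lossier.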
