Skip to content

Commit 9b9a4f7

Browse files
yairmovmn-robot
authored andcommitted
Latency regularizer based on the LogisticSigmoid approach outlined in
https://arxiv.org/abs/2006.09581 PiperOrigin-RevId: 331889063
1 parent 4cfb342 commit 9b9a4f7

File tree

2 files changed

+80
-5
lines changed

2 files changed

+80
-5
lines changed

morph_net/network_regularizers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ pytype_library(
113113
srcs = ["latency_regularizer.py"],
114114
deps = [
115115
":cost_calculator",
116+
":logistic_sigmoid_regularizer",
116117
":resource_function",
117118
"//third_party/py/morph_net/framework:batch_norm_source_op_handler",
118119
"//third_party/py/morph_net/framework:conv2d_transpose_source_op_handler",

morph_net/network_regularizers/latency_regularizer.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
"""A NetworkRegularizer that targets inference latency."""
22

3-
from __future__ import absolute_import
4-
from __future__ import division
5-
# [internal] enable type annotations
6-
from __future__ import print_function
3+
from typing import Type, List
74

85
from morph_net.framework import batch_norm_source_op_handler
96
from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler
@@ -15,9 +12,86 @@
1512
from morph_net.framework import op_handlers
1613
from morph_net.framework import op_regularizer_manager as orm
1714
from morph_net.network_regularizers import cost_calculator
15+
from morph_net.network_regularizers import logistic_sigmoid_regularizer
1816
from morph_net.network_regularizers import resource_function
1917
import tensorflow.compat.v1 as tf
20-
from typing import Type, List
18+
19+
20+
class LogisticSigmoidLatencyRegularizer(
21+
logistic_sigmoid_regularizer.LogisticSigmoidRegularizer):
22+
"""A LogisticSigmoidRegularizer that targets Latency.
23+
24+
Args:
25+
output_boundary: An OpRegularizer will be created for all these
26+
operations, and recursively for all ops they depend on via data
27+
dependency that does not involve ops from input_boundary.
28+
batch_size: Integer batch size to calculate cost/loss for.
29+
regularize_on_mask: Bool. If True uses the binary mask as the
30+
regularization vector. Else uses the probability vector.
31+
alive_threshold: Float. Threshold below which values are considered dead.
32+
This can be used both when mask_as_alive_vector is True and then the
33+
threshold is used to binarize the sampled values and
34+
when mask_as_alive_vector is False, and then the threshold is on the
35+
channel probability.
36+
mask_as_alive_vector: Bool. If True use the thresholded sampled mask
37+
as the alive vector. Else, use thresholded probabilities from the
38+
logits.
39+
regularizer_decorator: A string, the name of the regularizer decorators to
40+
use. Supported decorators are listed in
41+
op_regularizer_decorator.SUPPORTED_DECORATORS.
42+
decorator_parameters: A dictionary of parameters to pass to the decorator
43+
factory. To be used only with decorators that requires parameters,
44+
otherwise use None.
45+
input_boundary: A list of ops that represent the input boundary of the
46+
subgraph being regularized (input boundary is not regularized).
47+
force_group: List of regex for ops that should be force-grouped. Each
48+
regex corresponds to a separate group. Use '|' operator to specify
49+
multiple patterns in a single regex. See op_regularizer_manager for more
50+
detail.
51+
regularizer_blacklist: List of regex for ops that should not be
52+
regularized. See op_regularizer_manager for more detail.
53+
"""
54+
55+
def __init__(
56+
self,
57+
output_boundary: List[tf.Operation],
58+
hardware,
59+
batch_size=1,
60+
regularize_on_mask=True,
61+
alive_threshold=0.1,
62+
mask_as_alive_vector=True,
63+
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
64+
decorator_parameters=None,
65+
input_boundary: List[tf.Operation] = None,
66+
force_group=None,
67+
regularizer_blacklist=None):
68+
69+
self._hardware = hardware
70+
self._batch_size = batch_size
71+
72+
super().__init__(
73+
output_boundary=output_boundary,
74+
regularize_on_mask=regularize_on_mask,
75+
alive_threshold=alive_threshold,
76+
mask_as_alive_vector=mask_as_alive_vector,
77+
regularizer_decorator=regularizer_decorator,
78+
decorator_parameters=decorator_parameters,
79+
input_boundary=input_boundary,
80+
force_group=force_group,
81+
regularizer_blacklist=regularizer_blacklist)
82+
83+
def get_calculator(self):
84+
return cost_calculator.CostCalculator(
85+
self._manager, resource_function.latency_function_factory(
86+
self._hardware, self._batch_size))
87+
88+
@property
89+
def name(self):
90+
return 'LogisticSigmoidLatency'
91+
92+
@property
93+
def cost_name(self):
94+
return self._hardware + ' Latency'
2195

2296

2397
class GammaLatencyRegularizer(generic_regularizers.NetworkRegularizer):

0 commit comments

Comments
 (0)