|
1 | 1 | """A NetworkRegularizer that targets inference latency.""" |
2 | 2 |
|
3 | | -from __future__ import absolute_import |
4 | | -from __future__ import division |
5 | | -# [internal] enable type annotations |
6 | | -from __future__ import print_function |
| 3 | +from typing import Type, List |
7 | 4 |
|
8 | 5 | from morph_net.framework import batch_norm_source_op_handler |
9 | 6 | from morph_net.framework import conv2d_transpose_source_op_handler as conv2d_transpose_handler |
|
15 | 12 | from morph_net.framework import op_handlers |
16 | 13 | from morph_net.framework import op_regularizer_manager as orm |
17 | 14 | from morph_net.network_regularizers import cost_calculator |
| 15 | +from morph_net.network_regularizers import logistic_sigmoid_regularizer |
18 | 16 | from morph_net.network_regularizers import resource_function |
19 | 17 | import tensorflow.compat.v1 as tf |
20 | | -from typing import Type, List |
| 18 | + |
| 19 | + |
| 20 | +class LogisticSigmoidLatencyRegularizer( |
| 21 | + logistic_sigmoid_regularizer.LogisticSigmoidRegularizer): |
| 22 | + """A LogisticSigmoidRegularizer that targets Latency. |
| 23 | +
|
| 24 | + Args: |
| 25 | + output_boundary: An OpRegularizer will be created for all these |
| 26 | + operations, and recursively for all ops they depend on via data |
| 27 | + dependency that does not involve ops from input_boundary. |
| 28 | + batch_size: Integer batch size to calculate cost/loss for. |
| 29 | + regularize_on_mask: Bool. If True uses the binary mask as the |
| 30 | + regularization vector. Else uses the probability vector. |
| 31 | + alive_threshold: Float. Threshold below which values are considered dead. |
| 32 | + This can be used both when mask_as_alive_vector is True and then the |
| 33 | + threshold is used to binarize the sampled values and |
| 34 | + when mask_as_alive_vector is False, and then the threshold is on the |
| 35 | + channel probability. |
| 36 | + mask_as_alive_vector: Bool. If True use the thresholded sampled mask |
| 37 | + as the alive vector. Else, use thresholded probabilities from the |
| 38 | + logits. |
| 39 | + regularizer_decorator: A string, the name of the regularizer decorators to |
| 40 | + use. Supported decorators are listed in |
| 41 | + op_regularizer_decorator.SUPPORTED_DECORATORS. |
| 42 | + decorator_parameters: A dictionary of parameters to pass to the decorator |
| 43 | + factory. To be used only with decorators that requires parameters, |
| 44 | + otherwise use None. |
| 45 | + input_boundary: A list of ops that represent the input boundary of the |
| 46 | + subgraph being regularized (input boundary is not regularized). |
| 47 | + force_group: List of regex for ops that should be force-grouped. Each |
| 48 | + regex corresponds to a separate group. Use '|' operator to specify |
| 49 | + multiple patterns in a single regex. See op_regularizer_manager for more |
| 50 | + detail. |
| 51 | + regularizer_blacklist: List of regex for ops that should not be |
| 52 | + regularized. See op_regularizer_manager for more detail. |
| 53 | + """ |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + output_boundary: List[tf.Operation], |
| 58 | + hardware, |
| 59 | + batch_size=1, |
| 60 | + regularize_on_mask=True, |
| 61 | + alive_threshold=0.1, |
| 62 | + mask_as_alive_vector=True, |
| 63 | + regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None, |
| 64 | + decorator_parameters=None, |
| 65 | + input_boundary: List[tf.Operation] = None, |
| 66 | + force_group=None, |
| 67 | + regularizer_blacklist=None): |
| 68 | + |
| 69 | + self._hardware = hardware |
| 70 | + self._batch_size = batch_size |
| 71 | + |
| 72 | + super().__init__( |
| 73 | + output_boundary=output_boundary, |
| 74 | + regularize_on_mask=regularize_on_mask, |
| 75 | + alive_threshold=alive_threshold, |
| 76 | + mask_as_alive_vector=mask_as_alive_vector, |
| 77 | + regularizer_decorator=regularizer_decorator, |
| 78 | + decorator_parameters=decorator_parameters, |
| 79 | + input_boundary=input_boundary, |
| 80 | + force_group=force_group, |
| 81 | + regularizer_blacklist=regularizer_blacklist) |
| 82 | + |
| 83 | + def get_calculator(self): |
| 84 | + return cost_calculator.CostCalculator( |
| 85 | + self._manager, resource_function.latency_function_factory( |
| 86 | + self._hardware, self._batch_size)) |
| 87 | + |
| 88 | + @property |
| 89 | + def name(self): |
| 90 | + return 'LogisticSigmoidLatency' |
| 91 | + |
| 92 | + @property |
| 93 | + def cost_name(self): |
| 94 | + return self._hardware + ' Latency' |
21 | 95 |
|
22 | 96 |
|
23 | 97 | class GammaLatencyRegularizer(generic_regularizers.NetworkRegularizer): |
|
0 commit comments