Skip to content

Commit 5c0d969

Browse files
EddyLXJmeta-codesync[bot]
authored andcommitted
Support no eviction in Feature score eviction policy (#5059)
Summary: Pull Request resolved: #5059 X-link: meta-pytorch/torchrec#3488 X-link: https://github.com/facebookresearch/FBGEMM/pull/2068 As title If one table is using feature score eviction in one tbe, then all tables in this tbe need to use the same policy. Feature score eviction can support ttl based eviction now. This diff is adding support no eviction in feature score eviction policy. Reviewed By: emlin Differential Revision: D84660528 fbshipit-source-id: 28d85126da30b2c357c7f99b6535be381a76f2ce
1 parent dc333ed commit 5c0d969

File tree

8 files changed

+220
-6
lines changed

8 files changed

+220
-6
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class EvictionPolicy(NamedTuple):
120120
eviction_free_mem_check_interval_batch: Optional[int] = (
121121
None # Number of batches between checks for free memory threshold when using free_mem trigger mode.
122122
)
123+
enable_eviction_for_feature_score_eviction_policy: Optional[list[bool]] = (
124+
None # enable eviction if eviction policy is feature score, false means no eviction
125+
)
123126

124127
def validate(self) -> None:
125128
assert self.eviction_trigger_mode in [0, 1, 2, 3, 4, 5], (
@@ -217,13 +220,17 @@ def validate(self) -> None:
217220
"threshold_calculation_bucket_num must be set if eviction_strategy is 5,"
218221
f"actual {self.threshold_calculation_bucket_num}"
219222
)
223+
assert self.enable_eviction_for_feature_score_eviction_policy is not None, (
224+
"enable_eviction_for_feature_score_eviction_policy must be set if eviction_strategy is 5,"
225+
f"actual {self.enable_eviction_for_feature_score_eviction_policy}"
226+
)
220227
assert (
221-
len(self.training_id_keep_count)
228+
len(self.enable_eviction_for_feature_score_eviction_policy)
229+
== len(self.training_id_keep_count)
222230
== len(self.feature_score_counter_decay_rates)
223-
== len(self.training_id_eviction_trigger_count)
224231
), (
225-
"feature_score_thresholds, training_id_eviction_trigger_count and training_id_keep_count must have the same length, "
226-
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.training_id_eviction_trigger_count}"
232+
"feature_score_thresholds, enable_eviction_for_feature_score_eviction_policy, and training_id_keep_count must have the same length, "
233+
f"actual {self.training_id_keep_count} vs {self.feature_score_counter_decay_rates} vs {self.enable_eviction_for_feature_score_eviction_policy}"
227234
)
228235

229236

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,15 @@ def __init__(
707707
# If trigger mode is free_mem(5), populate config
708708
self.set_free_mem_eviction_trigger_config(eviction_policy)
709709

710+
enable_eviction_for_feature_score_eviction_policy = ( # pytorch api in c++ doesn't support vertor<bool>, convert to int here, 0: no eviction 1: eviction
711+
[
712+
int(x)
713+
for x in eviction_policy.enable_eviction_for_feature_score_eviction_policy
714+
]
715+
if eviction_policy.enable_eviction_for_feature_score_eviction_policy
716+
is not None
717+
else None
718+
)
710719
# Please refer to https://fburl.com/gdoc/nuupjwqq for the following eviction parameters.
711720
eviction_config = torch.classes.fbgemm.FeatureEvictConfig(
712721
eviction_policy.eviction_trigger_mode, # eviction is disabled, 0: disabled, 1: iteration, 2: mem_util, 3: manual, 4: id count
@@ -719,6 +728,7 @@ def __init__(
719728
eviction_policy.feature_score_counter_decay_rates, # feature_score_counter_decay_rates for each table if eviction strategy is feature score
720729
eviction_policy.training_id_eviction_trigger_count, # training_id_eviction_trigger_count for each table
721730
eviction_policy.training_id_keep_count, # training_id_keep_count for each table
731+
enable_eviction_for_feature_score_eviction_policy, # no eviction setting for feature score eviction policy
722732
eviction_policy.l2_weight_thresholds, # l2_weight_thresholds for each table if eviction strategy is feature l2 norm
723733
table_dims.tolist() if table_dims is not None else None,
724734
eviction_policy.threshold_calculation_bucket_stride, # threshold_calculation_bucket_stride if eviction strategy is feature score

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_inference_wrapper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ void DramKVEmbeddingInferenceWrapper::init(
7676
std::nullopt /* feature_score_counter_decay_rates */,
7777
std::nullopt /* training_id_eviction_trigger_count */,
7878
std::nullopt /* training_id_keep_count */,
79+
std::nullopt /* enable_eviction_for_feature_score_eviction_policy */,
7980
std::nullopt /* l2_weight_thresholds */,
8081
std::nullopt /* embedding_dims */,
8182
std::nullopt /* threshold_calculation_bucket_stride */,

fbgemm_gpu/src/dram_kv_embedding_cache/feature_evict.h

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
110110
std::optional<std::vector<double>> feature_score_counter_decay_rates,
111111
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count,
112112
std::optional<std::vector<int64_t>> training_id_keep_count,
113+
std::optional<std::vector<int8_t>>
114+
enable_eviction_for_feature_score_eviction_policy, // 0: no eviction,
115+
// 1: evict
113116
std::optional<std::vector<double>> l2_weight_thresholds,
114117
std::optional<std::vector<int64_t>> embedding_dims,
115118
std::optional<double> threshold_calculation_bucket_stride = 0.2,
@@ -129,6 +132,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
129132
training_id_eviction_trigger_count_(
130133
std::move(training_id_eviction_trigger_count)),
131134
training_id_keep_count_(std::move(training_id_keep_count)),
135+
enable_eviction_for_feature_score_eviction_policy_(
136+
std::move(enable_eviction_for_feature_score_eviction_policy)),
132137
l2_weight_thresholds_(l2_weight_thresholds),
133138
embedding_dims_(embedding_dims),
134139
threshold_calculation_bucket_stride_(
@@ -169,10 +174,17 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
169174
CHECK(
170175
training_id_eviction_trigger_count_.has_value() &&
171176
!training_id_eviction_trigger_count_.value().empty());
177+
CHECK(enable_eviction_for_feature_score_eviction_policy_.has_value());
178+
const auto& enable_eviction_vec =
179+
enable_eviction_for_feature_score_eviction_policy_.value();
172180
const auto& vec = training_id_eviction_trigger_count_.value();
173181
eviction_trigger_stats_log = ", training_id_eviction_trigger_count: [";
174182
total_id_eviction_trigger_count_ = 0;
175183
for (size_t i = 0; i < vec.size(); ++i) {
184+
if (enable_eviction_vec[i] == 0) {
185+
throw std::runtime_error(
186+
"ID_COUNT trigger mode doesn't not support enable_eviction=False, please use FREE_MEM trigger mode instead");
187+
}
176188
total_id_eviction_trigger_count_ =
177189
total_id_eviction_trigger_count_.value() + vec[i];
178190
if (vec[i] <= 0) {
@@ -212,6 +224,7 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
212224
CHECK(threshold_calculation_bucket_stride_.has_value());
213225
CHECK(threshold_calculation_bucket_num_.has_value());
214226
CHECK(ttls_in_mins_.has_value());
227+
CHECK(enable_eviction_for_feature_score_eviction_policy_.has_value());
215228
LOG(INFO) << "eviction config, trigger mode:"
216229
<< to_string(trigger_mode_) << eviction_trigger_stats_log
217230
<< ", strategy: " << to_string(trigger_strategy_)
@@ -223,7 +236,9 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
223236
<< ", threshold_calculation_bucket_num: "
224237
<< threshold_calculation_bucket_num_.value()
225238
<< ", feature_score_counter_decay_rates: "
226-
<< feature_score_counter_decay_rates_.value();
239+
<< feature_score_counter_decay_rates_.value()
240+
<< ", enable_eviction_for_feature_score_eviction_policy: "
241+
<< enable_eviction_for_feature_score_eviction_policy_.value();
227242
return;
228243
}
229244

@@ -281,6 +296,8 @@ struct FeatureEvictConfig : public torch::jit::CustomClassHolder {
281296
std::optional<std::vector<double>> feature_score_counter_decay_rates_;
282297
std::optional<std::vector<int64_t>> training_id_eviction_trigger_count_;
283298
std::optional<std::vector<int64_t>> training_id_keep_count_;
299+
std::optional<std::vector<int8_t>>
300+
enable_eviction_for_feature_score_eviction_policy_;
284301
std::optional<int64_t> total_id_eviction_trigger_count_;
285302
std::optional<std::vector<double>> l2_weight_thresholds_;
286303
std::optional<std::vector<int64_t>> embedding_dims_;
@@ -984,6 +1001,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
9841001
const std::vector<int64_t>& training_id_eviction_trigger_count,
9851002
const std::vector<int64_t>& training_id_keep_count,
9861003
const std::vector<int64_t>& ttls_in_mins,
1004+
const std::vector<int8_t>&
1005+
enable_eviction_for_feature_score_eviction_policy,
9871006
const double threshold_calculation_bucket_stride,
9881007
const int64_t threshold_calculation_bucket_num,
9891008
int64_t interval_for_insufficient_eviction_s,
@@ -1003,6 +1022,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
10031022
training_id_eviction_trigger_count_(training_id_eviction_trigger_count),
10041023
training_id_keep_count_(training_id_keep_count),
10051024
ttls_in_mins_(ttls_in_mins),
1025+
enable_eviction_for_feature_score_eviction_policy_(
1026+
enable_eviction_for_feature_score_eviction_policy),
10061027
threshold_calculation_bucket_stride_(
10071028
threshold_calculation_bucket_stride),
10081029
num_buckets_(threshold_calculation_bucket_num),
@@ -1071,6 +1092,13 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
10711092
protected:
10721093
bool evict_block(weight_type* block, int sub_table_id, int shard_id)
10731094
override {
1095+
int8_t enable_eviction =
1096+
enable_eviction_for_feature_score_eviction_policy_[sub_table_id];
1097+
if (enable_eviction == 0) {
1098+
// If enable_eviction is set to 0, we don't evict any block.
1099+
return false;
1100+
}
1101+
10741102
double ttls_threshold = ttls_in_mins_[sub_table_id];
10751103
if (ttls_threshold > 0) {
10761104
auto current_time = FixedBlockPool::current_timestamp();
@@ -1145,6 +1173,15 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
11451173

11461174
void compute_thresholds_from_buckets() {
11471175
for (size_t table_id = 0; table_id < num_tables_; ++table_id) {
1176+
int8_t enable_eviction =
1177+
enable_eviction_for_feature_score_eviction_policy_[table_id];
1178+
if (enable_eviction == 0) {
1179+
// If enable_eviction is set to 0, we don't evict any block.
1180+
thresholds_[table_id] = 0.0;
1181+
evict_modes_[table_id] = EvictMode::NONE;
1182+
continue;
1183+
}
1184+
11481185
int64_t total = 0;
11491186

11501187
if (ttls_in_mins_[table_id] > 0) {
@@ -1209,7 +1246,8 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12091246
<< " threshold bucket: " << threshold_bucket
12101247
<< " actual evict count: " << acc_count
12111248
<< " target evict count: " << evict_count
1212-
<< " total count: " << total;
1249+
<< " total count: " << total
1250+
<< " evict mode: " << to_string(evict_modes_[table_id]);
12131251

12141252
for (int table_id = 0; table_id < num_tables_; ++table_id) {
12151253
this->metrics_.eviction_threshold_with_dry_run[table_id] =
@@ -1226,6 +1264,16 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12261264
THRESHOLD // blocks with scores below the computed threshold will be
12271265
// evicted
12281266
};
1267+
inline std::string to_string(EvictMode mode) {
1268+
switch (mode) {
1269+
case EvictMode::NONE:
1270+
return "NONE";
1271+
case EvictMode::ONLY_ZERO:
1272+
return "ONLY_ZERO";
1273+
case EvictMode::THRESHOLD:
1274+
return "THRESHOLD";
1275+
}
1276+
}
12291277
std::vector<EvictMode> evict_modes_;
12301278

12311279
const int num_tables_ = static_cast<int>(this->sub_table_hash_cumsum_.size());
@@ -1240,6 +1288,7 @@ class FeatureScoreBasedEvict : public FeatureEvict<weight_type> {
12401288
// eviction.
12411289

12421290
const std::vector<int64_t>& ttls_in_mins_; // Time-to-live for eviction.
1291+
const std::vector<int8_t>& enable_eviction_for_feature_score_eviction_policy_;
12431292
std::vector<std::vector<std::vector<size_t>>>
12441293
local_buckets_per_shard_per_table_;
12451294
std::vector<std::vector<size_t>> local_blocks_num_per_shard_per_table_;
@@ -1489,6 +1538,7 @@ std::unique_ptr<FeatureEvict<weight_type>> create_feature_evict(
14891538
config->training_id_eviction_trigger_count_.value(),
14901539
config->training_id_keep_count_.value(),
14911540
config->ttls_in_mins_.value(),
1541+
config->enable_eviction_for_feature_score_eviction_policy_.value(),
14921542
config->threshold_calculation_bucket_stride_.value(),
14931543
config->threshold_calculation_bucket_num_.value(),
14941544
config->interval_for_insufficient_eviction_s_,

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,7 @@ static auto feature_evict_config =
737737
std::optional<std::vector<double>>,
738738
std::optional<std::vector<int64_t>>,
739739
std::optional<std::vector<int64_t>>,
740+
std::optional<std::vector<int8_t>>,
740741
std::optional<std::vector<double>>,
741742
std::optional<std::vector<int64_t>>,
742743
std::optional<double>,
@@ -756,6 +757,9 @@ static auto feature_evict_config =
756757
torch::arg("feature_score_counter_decay_rates") = std::nullopt,
757758
torch::arg("training_id_eviction_trigger_count") = std::nullopt,
758759
torch::arg("training_id_keep_count") = std::nullopt,
760+
torch::arg(
761+
"enable_eviction_for_feature_score_eviction_policy") =
762+
std::nullopt,
759763
torch::arg("l2_weight_thresholds") = std::nullopt,
760764
torch::arg("embedding_dims") = std::nullopt,
761765
torch::arg("threshold_calculation_bucket_stride") = 0.2,

0 commit comments

Comments
 (0)