
Commit 2ec56df

Fix preemption for sparse attention module and add attention sink. (#333)
* handle preempt in ESA and add init_window
1 parent 8622635 commit 2ec56df
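
The "attention sink" in this commit is the `init_window`: the first KV-cache blocks of the sequence, which ESA keeps resident alongside the local window and the retrieved blocks. Below is a minimal sketch of that selection idea, assuming per-block retrieval scores — the function and its logic are illustrative, not the UCM implementation; only the config key names come from the docs diff further down.

from typing import List


def select_kv_blocks(
    scores: List[float],   # per-block retrieval scores for one request (assumed input)
    init_window_sz: int,   # attention-sink blocks pinned at the sequence start
    local_window_sz: int,  # newest blocks, always kept
    min_blocks: int,       # below this count, skip sparsification entirely
    sparse_ratio: float,   # fraction of blocks to keep overall
) -> List[int]:
    n = len(scores)
    if n <= min_blocks:
        return list(range(n))  # too short to bother sparsifying
    budget = max(int(n * sparse_ratio), init_window_sz + local_window_sz)
    keep = set(range(init_window_sz))           # attention sink: first blocks
    keep |= set(range(n - local_window_sz, n))  # local window: newest blocks
    # Spend any remaining budget on the highest-scoring middle blocks.
    middle = sorted(
        range(init_window_sz, n - local_window_sz),
        key=lambda i: scores[i],
        reverse=True,
    )
    for i in middle:
        if len(keep) >= budget:
            break
        keep.add(i)
    return sorted(keep)


# 10 blocks with 1 sink block, 2 local blocks, ~20% kept overall:
print(select_kv_blocks([0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6, 0.5, 0.0],
                       init_window_sz=1, local_window_sz=2,
                       min_blocks=4, sparse_ratio=0.2))  # -> [0, 8, 9]

Pinning the first blocks follows the attention-sink observation that initial tokens receive a disproportionate share of attention mass in long contexts, so evicting them degrades generation quality.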

File tree

7 files changed: +314 −234 lines changed

docs/source/user-guide/sparse-attention/esa.md

Lines changed: 6 additions & 6 deletions
@@ -31,8 +31,8 @@ ktc = KVTransferConfig(
             "init_window_sz": 1,
             "local_window_sz": 2,
             "min_blocks": 4,
-            "sparse_ratio": 0.3,
-            "retrieval_stride": 5,
+            "sparse_ratio": 0.2,
+            "retrieval_stride": 10,
         }
     },
 },
@@ -80,8 +80,8 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci
         "init_window_sz": 1,
         "local_window_sz": 2,
         "min_blocks": 4,
-        "sparse_ratio": 0.3,
-        "retrieval_stride": 5
+        "sparse_ratio": 0.2,
+        "retrieval_stride": 10
     }
 },
 ```
@@ -92,5 +92,5 @@ The following results were obtained using `Qwen2.5-14B-Instruct` under the speci
 We use [LongBench](https://huggingface.co/datasets/zai-org/LongBench) to evaluate the accuracy of the ESA algorithm.
 | Dataset | F1-Score |
 |-------|-----------|
-| multifieldqa_zh | 59.4 |
-| dureader | 26.4 |
+| multifieldqa_zh | 64.28 |
+| dureader | 28.73 |
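
Reading the two doc hunks together, the ESA block now looks roughly like the sketch below. Only the five ESA keys and the closing braces are visible in this diff, so the connector fields, the wrapper key names, and the import path are assumptions for illustration:

from vllm.config import KVTransferConfig  # import path may vary across vLLM versions

ktc = KVTransferConfig(
    kv_connector="UCConnector",  # assumption: the real connector name is not shown in this diff
    kv_role="kv_both",
    kv_connector_extra_config={
        "ucm_sparse_config": {  # assumption: wrapper keys are not visible in the hunks
            "ESA": {
                "init_window_sz": 1,
                "local_window_sz": 2,
                "min_blocks": 4,
                "sparse_ratio": 0.2,     # was 0.3
                "retrieval_stride": 10,  # was 5
            }
        },
    },
)

If retrieval_stride is the number of decode steps between block re-retrievals, the new settings keep fewer blocks (0.2 vs. 0.3) but rescore them half as often; the updated LongBench scores above were measured with this configuration.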

examples/offline_inference_esa.py

Lines changed: 20 additions & 16 deletions
@@ -45,10 +45,10 @@ def setup_environment_variables():
         sys.exit(1)
 
     data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
-    data_dir = input(
-        "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
-    )
     if not os.path.isdir(data_dir):
+        data_dir = input(
+            "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
+        )
         create = input(f"Directory {data_dir} dose not exist. Create it? (Y/n): ")
         if create.lower() == "y":
             os.makedirs(data_dir, exist_ok=True)
@@ -87,7 +87,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
         model=model,
         kv_transfer_config=ktc,
         max_model_len=32768,
-        gpu_memory_utilization=0.6,
+        gpu_memory_utilization=0.8,
         max_num_batched_tokens=30000,
         block_size=128,
         enforce_eager=True,
@@ -111,10 +111,14 @@ def print_output(
     start = time.time()
     outputs = llm.generate(prompt, sampling_params)
     print("-" * 50)
+    lines = []
     for output in outputs:
         generated_text = output.outputs[0].text
         print(f"Generated text: {generated_text!r}")
+        lines.append(generated_text + "\n")
     print(f"Generation took {time.time() - start:.2f} seconds, {req_str} request done.")
+    with open("./newest_out.txt", "w") as f:
+        f.writelines(lines)
     print("-" * 50)
 
 
@@ -140,24 +144,24 @@ def get_prompt(prompt):
 
 with build_llm_with_uc(module_path, name, model) as llm:
     prompts = []
-    batch_size = 5
+    batch_size = 20
     assert os.path.isfile(
         path_to_dataset
    ), f"Incorrect dataset path. Please specify the dataset path by `export DATASET_PATH=/path/to/longbench/multifieldqa_zh.jsonl`"
     with open(path_to_dataset, "r") as f:
-        for _ in range(batch_size):
-            line = f.readline()
-            if not line:
-                break
-            data = json.loads(line)
-            context = data["context"]
-            question = data["input"]
-            prompts.append(get_prompt(f"{context}\n\n{question}"))
-
-    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
+        lines = f.readlines()
+        for i in range(batch_size):
+            line = lines[i]
+            data = json.loads(line)
+            context = data["context"]
+            question = data["input"]
+            prompts.append(get_prompt(f"{context}\n\n{question}"))
+
+    sampling_params = SamplingParams(
+        temperature=0, top_p=0.95, max_tokens=256, ignore_eos=False
+    )
 
     print_output(llm, prompts, sampling_params, "first")
-    print_output(llm, prompts, sampling_params, "second")
 
 
 if __name__ == "__main__":
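
One behavioral note on the revised loading loop: the old readline() loop stopped cleanly when the file ran short, whereas lines[i] raises IndexError if the dataset has fewer than batch_size lines. A defensive variant of the same loop (a sketch; path_to_dataset and get_prompt as defined in the script):

import json

batch_size = 20
prompts = []
with open(path_to_dataset, "r") as f:
    # Slicing tolerates files shorter than batch_size.
    for line in f.readlines()[:batch_size]:
        data = json.loads(line)
        prompts.append(get_prompt(f"{data['context']}\n\n{data['input']}"))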
