|
3 | 3 | from typing import List |
4 | 4 |
|
5 | 5 |
|
6 | | -class TreeNode: |
7 | | - def __init__(self, start, end, val=0, left=None, right=None): |
8 | | - self.val = val |
9 | | - self.start = start |
10 | | - self.end = end |
11 | | - self.left = left |
12 | | - self.right = right |
| 6 | +class SegmentTree: |
| 7 | + def __init__(self, size: int): |
| 8 | + self.n = size |
| 9 | + self.segTree = [0 for _ in range(4 * size)] |
13 | 10 |
|
| 11 | + def update(self, idx: int, val: int, nodeIdx=0, start=0, end=None): |
| 12 | + if end is None: |
| 13 | + end = self.n - 1 |
| 14 | + if start == end: |
| 15 | + self.segTree[nodeIdx] += val |
| 16 | + return |
14 | 17 |
|
15 | | -class SegmentTree: |
16 | | - def __init__(self, n): |
17 | | - self.root = self.build(0, n - 1) |
18 | | - |
19 | | - def build(self, l, r): |
20 | | - if l == r: |
21 | | - return TreeNode(l, r, 0) |
22 | | - leftTree = self.build(l, (l + r) // 2) |
23 | | - rightTree = self.build((l + r) // 2 + 1, r) |
24 | | - return TreeNode(l, r, 0, leftTree, rightTree) |
25 | | - |
26 | | - def update(self, root, index, value): |
27 | | - if root.start == root.end == index: |
28 | | - root.val += value |
29 | | - return root.val |
30 | | - if root.start > index or root.end < index: |
31 | | - return root.val |
32 | | - root.val = self.update(root.left, index, value) + \ |
33 | | - self.update(root.right, index, value) |
34 | | - return root.val |
35 | | - |
36 | | - def query(self, root, l, r) -> int: |
37 | | - if root.start > r or root.end < l: |
| 18 | + mid = (start + end) // 2 |
| 19 | + leftIdx = 2 * nodeIdx + 1 |
| 20 | + rightIdx = 2 * nodeIdx + 2 |
| 21 | + |
| 22 | + if idx <= mid: |
| 23 | + self.update(idx, val, leftIdx, start, mid) |
| 24 | + else: |
| 25 | + self.update(idx, val, rightIdx, mid + 1, end) |
| 26 | + |
| 27 | + self.segTree[nodeIdx] = self.segTree[leftIdx] + self.segTree[rightIdx] |
| 28 | + |
| 29 | + def query(self, left: int, right: int, nodeIdx=0, start=0, end=None) -> int: |
| 30 | + if end is None: |
| 31 | + end = self.n - 1 |
| 32 | + if right < start or left > end: |
38 | 33 | return 0 |
39 | | - if l <= root.start and root.end <= r: |
40 | | - return root.val |
41 | | - return self.query(root.left, l, r) + self.query(root.right, l, r) |
| 34 | + if left <= start and end <= right: |
| 35 | + return self.segTree[nodeIdx] |
| 36 | + |
| 37 | + mid = (start + end) // 2 |
| 38 | + leftIdx = 2 * nodeIdx + 1 |
| 39 | + rightIdx = 2 * nodeIdx + 2 |
| 40 | + return self.query(left, right, leftIdx, start, mid) + self.query(left, right, rightIdx, mid + 1, end) |
42 | 41 |
|
43 | 42 |
|
44 | 43 | class Solution: |
45 | 44 | def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int: |
46 | | - presums = [0] |
| 45 | + preSums = [0] |
47 | 46 | for num in nums: |
48 | | - presums.append(presums[-1] + num) |
| 47 | + preSums.append(preSums[-1] + num) |
| 48 | + |
49 | 49 | allSums = set() |
50 | | - for presum in presums: |
51 | | - allSums.add(presum) |
52 | | - allSums.add(presum - lower) |
53 | | - allSums.add(presum - upper) |
54 | | - sortedSum = sorted(allSums) |
55 | | - rankMap = {val: idx for idx, val in enumerate(sortedSum)} |
56 | | - tree = SegmentTree(len(sortedSum)) |
57 | | - result = 0 |
58 | | - for presum in presums: |
59 | | - left = rankMap[presum - upper] |
60 | | - right = rankMap[presum - lower] |
61 | | - result += tree.query(tree.root, left, right) |
62 | | - tree.update(tree.root, rankMap[presum], 1) |
63 | | - return result |
| 50 | + for preSum in preSums: |
| 51 | + allSums.add(preSum) |
| 52 | + allSums.add(preSum - lower) |
| 53 | + allSums.add(preSum - upper) |
| 54 | + |
| 55 | + sortedSums = sorted(allSums) |
| 56 | + rankMap = {val: idx for idx, val in enumerate(sortedSums)} |
| 57 | + |
| 58 | + tree = SegmentTree(len(sortedSums)) |
| 59 | + count = 0 |
| 60 | + |
| 61 | + for preSum in preSums: |
| 62 | + left = rankMap[preSum - upper] |
| 63 | + right = rankMap[preSum - lower] |
| 64 | + count += tree.query(left, right) |
| 65 | + tree.update(rankMap[preSum], 1) |
| 66 | + |
| 67 | + return count |
64 | 68 |
|
65 | 69 |
|
66 | 70 | nums = [-2, 5, -1] |
|
0 commit comments