Segment Tree

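A segment tree, also known as a statistic tree, is a tree data structure for storing information about intervals, or segments. In its classic form it allows querying which of the stored segments contain a given point, and it is, in principle, a static structure: it cannot be modified once it is built. A similar data structure is the interval tree. The code below implements the closely related variant common in programming problems. Each node stores the sum of the array elements in its index range, so both point updates (`modify`) and range-sum queries (`query`) run in O(log n) time.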

m2o4Og.png

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 * A segment tree over a contiguous index range of an {@code int[]}. It supports point updates and
 * inclusive range-sum queries, each in O(log n) time.
 *
 * <p>Sums are accumulated in {@code int} and wrap on overflow, like ordinary Java {@code int}
 * arithmetic. Instances are mutable (see {@link #modify}) and are not synchronized.
 */
public class SegmentTree {

    /**
     * A node covering the inclusive index range {@code [start, end]} and storing its sum.
     *
     * <p>Nested (rather than a second public top-level class) so that {@code SegmentTree} may
     * access its private fields and so the file is a legal compilation unit.
     */
    private static final class SegmentTreeNode {
        private final int start;
        private final int end;
        private int sum;
        private SegmentTreeNode left;
        private SegmentTreeNode right;

        SegmentTreeNode(int start, int end, int sum) {
            this.start = start;
            this.end = end;
            this.sum = sum;
        }
    }

    /** Root of the tree, or {@code null} when the tree was built over an empty range. */
    private final SegmentTreeNode root;

    /**
     * Builds a segment tree over {@code nums[start..end]} (both ends inclusive).
     *
     * <p>An empty range ({@code start > end}) yields an empty tree. Every later {@link #query}
     * or {@link #modify} call on it throws {@link IndexOutOfBoundsException}.
     *
     * @param nums source values; only elements with indices in {@code [start, end]} are read, and
     *     the array is not retained
     * @param start first index covered by the tree (inclusive)
     * @param end last index covered by the tree (inclusive)
     * @throws ArrayIndexOutOfBoundsException if the range is non-empty and not inside {@code nums}
     * @throws NullPointerException if {@code nums} is null and the range is non-empty
     */
    public SegmentTree(int[] nums, int start, int end) {
        this.root = build(nums, start, end);
    }

    /** Recursively builds the subtree for {@code [start, end]}; returns null for an empty range. */
    private static SegmentTreeNode build(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }

        SegmentTreeNode node = new SegmentTreeNode(start, end, 0);
        if (start == end) {
            node.sum = nums[start];
        } else {
            int mid = midpoint(start, end);
            node.left = build(nums, start, mid);
            node.right = build(nums, mid + 1, end);
            node.sum = node.left.sum + node.right.sum;
        }
        return node;
    }

    /**
     * Replaces the value at {@code index} with {@code value} and updates every affected sum.
     *
     * @param index position to update; must lie within the range the tree was built over
     * @param value new value stored at {@code index}
     * @throws IndexOutOfBoundsException if the tree is empty or {@code index} is outside its range
     */
    public void modify(int index, int value) {
        checkRange(index, index);
        modify(root, index, value);
    }

    /** Descends to the leaf for {@code index}, sets it, and recomputes sums on the way back up. */
    private static void modify(SegmentTreeNode node, int index, int value) {
        // The index was validated against the root's range, so the leaf reached here is index.
        if (node.start == node.end) {
            node.sum = value;
            return;
        }

        int mid = midpoint(node.start, node.end);
        if (index <= mid) {
            modify(node.left, index, value);
        } else {
            modify(node.right, index, value);
        }
        node.sum = node.left.sum + node.right.sum;
    }

    /**
     * Returns the sum of the values at indices {@code start..end} (both inclusive).
     *
     * @param start first index of the queried range
     * @param end last index of the queried range
     * @return the sum over {@code [start, end]}, with {@code int} wrap-around on overflow
     * @throws IllegalArgumentException if {@code start > end}
     * @throws IndexOutOfBoundsException if the tree is empty or the range is not inside it
     */
    public int query(int start, int end) {
        if (start > end) {
            throw new IllegalArgumentException("start (" + start + ") > end (" + end + ")");
        }
        checkRange(start, end);
        return query(root, start, end);
    }

    /** Sums {@code [start, end]}, which the caller guarantees lies within {@code node}'s range. */
    private static int query(SegmentTreeNode node, int start, int end) {
        if (start == node.start && end == node.end) {
            return node.sum;
        }

        int mid = midpoint(node.start, node.end);
        if (end <= mid) {
            return query(node.left, start, end);
        } else if (start > mid) {
            return query(node.right, start, end);
        } else {
            // The range straddles the split point: sum the part in each child.
            return query(node.left, start, mid) + query(node.right, mid + 1, end);
        }
    }

    /** Throws unless the non-empty range {@code [from, to]} lies within the tree's range. */
    private void checkRange(int from, int to) {
        if (root == null) {
            throw new IndexOutOfBoundsException("segment tree is empty");
        }
        if (from < root.start || to > root.end) {
            throw new IndexOutOfBoundsException(
                    "range [" + from + ", " + to + "] is outside ["
                            + root.start + ", " + root.end + "]");
        }
    }

    /**
     * Returns the split point of {@code [lo, hi]}.
     *
     * <p>This form avoids the {@code int} overflow of {@code lo + hi}. It also always moves
     * toward {@code lo}, so a negative {@code lo} cannot cause endless recursion in
     * {@link #build}.
     */
    private static int midpoint(int lo, int hi) {
        return lo + (hi - lo) / 2;
    }
}