1
2
3
4
5 package math.geom2d.point;
6
7
8 import java.util.ArrayList;
9 import java.util.Collection;
10 import java.util.Collections;
11 import java.util.Comparator;
12 import java.util.List;
13
14 import math.geom2d.Box2D;
15 import math.geom2d.Point2D;
16 import math.geom2d.Vector2D;
17 import math.geom2d.line.StraightLine2D;
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 public class KDTree2D {
33
34 public class Node{
35 private Point2D point;
36 private Node left;
37 private Node right;
38
39 public Node(Point2D point){
40 this.point = point;
41 this.left = null;
42 this.right = null;
43 }
44
45 public Node(Point2D point, Node left, Node right){
46 this.point = point;
47 this.left = left;
48 this.right = right;
49 }
50
51 public Point2D getPoint() {
52 return point;
53 }
54
55 public Node getLeftChild() {
56 return left;
57 }
58
59 public Node getRightChild() {
60 return right;
61 }
62
63 public boolean isLeaf() {
64 return left==null && right==null;
65 }
66 }
67
68 private class XComparator implements Comparator<Point2D> {
69
70
71
72 public int compare(Point2D p1, Point2D p2){
73 if(p1.getX()<p2.getX())
74 return -1;
75 if(p1.getX()>p2.getX())
76 return +1;
77 return Double.compare(p1.getY(), p2.getY());
78 }
79 }
80
81 private class YComparator implements Comparator<Point2D> {
82
83
84
85 public int compare(Point2D p1, Point2D p2){
86 if(p1.getY()<p2.getY())
87 return -1;
88 if(p1.getY()>p2.getY())
89 return +1;
90 return Double.compare(p1.getX(), p2.getX());
91 }
92 }
93
94 private Node root;
95
96 private Comparator<Point2D> xComparator;
97 private Comparator<Point2D> yComparator;
98
99 public KDTree2D(ArrayList<Point2D> points) {
100 this.xComparator = new XComparator();
101 this.yComparator = new YComparator();
102 root = makeTree(points, 0);
103 }
104
105 private Node makeTree(List<Point2D> points, int depth) {
106
107 if(points.size()==0)
108 return null;
109
110
111 int dir = depth%2;
112
113
114 if(dir==0){
115
116 Collections.sort(points, xComparator);
117 }else{
118
119 Collections.sort(points, yComparator);
120 }
121
122 int n = points.size();
123 int med = n/2;
124
125 return new Node(
126 points.get(med),
127 makeTree(points.subList(0, med), depth+1),
128 makeTree(points.subList(med+1, n), depth+1));
129 }
130
131 public Node getRoot() {
132 return root;
133 }
134
135 public boolean contains(Point2D value){
136 return contains(value, root, 0);
137 }
138
139 private boolean contains(Point2D point, Node node, int depth){
140 if(node==null) return false;
141
142
143 int dir = depth%2;
144
145
146 int res;
147 if(dir==0){
148
149 res = xComparator.compare(point, node.point);
150 }else{
151
152 res = yComparator.compare(point, node.point);
153 }
154
155 if(res<0)
156 return contains(point, node.left, depth+1);
157 if(res>0)
158 return contains(point, node.right, depth+1);
159
160 return true;
161 }
162
163 public Node getNode(Point2D point) {
164 return getNode(point, root, 0);
165 }
166
167 private Node getNode(Point2D point, Node node, int depth){
168 if(node==null) return null;
169
170 int dir = depth%2;
171
172
173 int res;
174 if(dir==0){
175
176 res = xComparator.compare(point, node.point);
177 }else{
178
179 res = yComparator.compare(point, node.point);
180 }
181
182 if(res<0)
183 return getNode(point, node.left, depth+1);
184 if(res>0)
185 return getNode(point, node.right, depth+1);
186
187 return node;
188 }
189
190 public void add(Point2D point){
191 add(point, root, 0);
192 }
193
194 private void add(Point2D point, Node node, int depth) {
195
196 int dir = depth%2;
197
198
199 int res;
200 if(dir==0){
201
202 res = xComparator.compare(point, node.point);
203 }else{
204
205 res = yComparator.compare(point, node.point);
206 }
207
208 if(res<0){
209 if(node.left==null)
210 node.left = new Node(point);
211 else
212 add(point, node.left, depth+1);
213 }
214 if(res>0)
215 if(node.right==null)
216 node.right = new Node(point);
217 else
218 add(point, node.right, depth+1);
219 }
220
221 public Collection<Point2D> rangeSearch(Box2D range) {
222 ArrayList<Point2D> points = new ArrayList<Point2D>();
223 rangeSearch(range, points, root, 0);
224 return points;
225 }
226
227
228
229
230 private void rangeSearch(Box2D range,
231 Collection<Point2D> points, Node node, int depth) {
232 if(node==null)
233 return;
234
235
236 Point2D point = node.getPoint();
237 double x = point.getX();
238 double y = point.getY();
239
240
241 boolean tx1 = range.getMinX()<x;
242 boolean ty1 = range.getMinY()<y;
243 boolean tx2 = x<=range.getMaxX();
244 boolean ty2 = y<=range.getMaxY();
245
246
247 if(tx1 && tx2 && ty1 && ty2)
248 points.add(point);
249
250
251 int dir = depth%2;
252
253 if(dir==0 ? tx1 : ty1)
254 rangeSearch(range, points, node.left, depth+1);
255 if(dir==0 ? tx2 : ty2)
256 rangeSearch(range, points, node.right, depth+1);
257 }
258
259
260 public Point2D nearestNeighbor(Point2D point) {
261 return nearestNeighbor(point, root, root, 0).getPoint();
262 }
263
264
265
266
267
268 private Node nearestNeighbor(Point2D point, Node candidate, Node node,
269 int depth) {
270
271 double distCand = candidate.point.getDistance(point);
272 double dist = node.point.getDistance(point);
273 if(dist<distCand){
274 candidate = node;
275 }
276
277
278 int dir = depth%2;
279
280 Node node1, node2;
281
282
283
284 Point2D anchor = node.getPoint();
285 StraightLine2D line;
286 if(dir==0){
287 boolean b = point.getX()<anchor.getX();
288 node1 = b ? node.left : node.right;
289 node2 = b ? node.right : node.left;
290 line = StraightLine2D.create(anchor, new Vector2D(0, 1));
291 } else {
292 boolean b = point.getY()<anchor.getY();
293 node1 = b ? node.left : node.right;
294 node2 = b ? node.right : node.left;
295 line = StraightLine2D.create(anchor, new Vector2D(1, 0));
296 }
297
298 if(node1!=null) {
299
300 candidate = nearestNeighbor(point, candidate, node1, depth+1);
301
302
303 distCand = candidate.getPoint().getDistance(point);
304 }
305
306
307
308 if(line.getDistance(point)<distCand && node2!=null) {
309 candidate = nearestNeighbor(point, candidate, node2, depth+1);
310 }
311
312 return candidate;
313 }
314
315
316
317
318
319 public static void main(String[] args){
320 int n = 3;
321 ArrayList<Point2D> points = new ArrayList<Point2D>(n);
322 points.add(new Point2D(5, 5));
323 points.add(new Point2D(10, 10));
324 points.add(new Point2D(20, 20));
325
326 System.out.println("Check KDTree2D");
327
328 KDTree2D tree = new KDTree2D(points);
329
330 System.out.println(tree.contains(new Point2D(5, 5)));
331 System.out.println(tree.contains(new Point2D(6, 5)));
332 }
333
334 }