View Javadoc

1    /**
2    * file KDTree2D.java
3    */
4   
5   package math.geom2d.point;
6   
7   
8   import java.util.ArrayList;
9   import java.util.Collection;
10  import java.util.Collections;
11  import java.util.Comparator;
12  import java.util.List;
13  
14  import math.geom2d.Box2D;
15  import math.geom2d.Point2D;
16  import math.geom2d.Vector2D;
17  import math.geom2d.line.StraightLine2D;
18  
19  
20  /**
21   * A data structure for storing a great number of points. During construction
22   * of the tree, median point in current coordinate is chosen for each step,
23   * ensuring the final tree is balanced. The cost for retrieving a point is 
24   * O(log n).<br>
25   * The cost for building the tree is O(n log^2 n), that can take some time for
26   * large points sets.<br>
27   * This implementation is semi-dynamic: points can be added, but can not be
28   * removed.
29   * @author dlegland
30   *
31   */
32  public class KDTree2D {
33      //TODO: make KDTree2D implements PointSet2D
34      public class Node{
35          private Point2D point;
36          private Node left;
37          private Node right;
38          
39          public Node(Point2D point){
40              this.point  = point;
41              this.left   = null;
42              this.right  = null;
43          }
44  
45          public Node(Point2D point, Node left, Node right){
46              this.point  = point;
47              this.left   = left;
48              this.right  = right;
49          }
50          
51          public Point2D getPoint() {
52              return point;
53          }
54          
55          public Node getLeftChild() {
56              return left;
57          }
58          
59          public Node getRightChild() {
60              return right;
61          }
62          
63          public boolean isLeaf() {
64              return left==null && right==null;
65          }
66      }
67          
68      private class XComparator implements Comparator<Point2D> {
69          /* (non-Javadoc)
70           * @see java.util.Comparator#compare(java.lang.Object, java.lang.Object)
71           */
72          public int compare(Point2D p1, Point2D p2){
73              if(p1.getX()<p2.getX())
74                  return -1;
75              if(p1.getX()>p2.getX())
76                  return +1;
77              return Double.compare(p1.getY(), p2.getY());
78          }
79      }
80      
81      private class YComparator implements Comparator<Point2D> {
82          /* (non-Javadoc)
83           * @see java.util.Comparator#compare(java.lang.Object, java.lang.Object)
84           */
85          public int compare(Point2D p1, Point2D p2){
86              if(p1.getY()<p2.getY())
87                  return -1;
88              if(p1.getY()>p2.getY())
89                  return +1;
90              return Double.compare(p1.getX(), p2.getX());
91          }
92      }
93     
94      private Node root;
95      
96      private Comparator<Point2D> xComparator;
97      private Comparator<Point2D> yComparator;
98      
99      public KDTree2D(ArrayList<Point2D> points) {
100         this.xComparator = new XComparator();
101         this.yComparator = new YComparator();
102         root = makeTree(points, 0);
103     }
104         
105     private Node makeTree(List<Point2D> points, int depth) {
106         // Add a leaf
107         if(points.size()==0)
108             return null;
109         
110         // select direction
111         int dir = depth%2;
112         
113         // sort points according to i-th dimension
114         if(dir==0){
115             // Compare points based on their x-coordinate
116             Collections.sort(points, xComparator);
117         }else{
118             // Compare points based on their x-coordinate
119             Collections.sort(points, yComparator);
120         }
121         
122         int n = points.size();
123         int med = n/2;    // compute median
124         
125         return new Node(
126                 points.get(med),
127                 makeTree(points.subList(0, med), depth+1),
128                 makeTree(points.subList(med+1, n), depth+1));
129     }
130 
131     public Node getRoot() {
132         return root;
133     }
134     
135     public boolean contains(Point2D value){
136         return contains(value, root, 0);
137     }
138     
139     private boolean contains(Point2D point, Node node, int depth){
140         if(node==null) return false;
141         
142         // select direction
143         int dir = depth%2;
144         
145         // sort points according to i-th dimension
146         int res;
147         if(dir==0){
148             // Compare points based on their x-coordinate
149             res = xComparator.compare(point, node.point);
150         }else{
151             // Compare points based on their x-coordinate
152             res = yComparator.compare(point, node.point);
153         }
154         
155         if(res<0)
156             return contains(point, node.left, depth+1);
157         if(res>0)
158             return contains(point, node.right, depth+1);
159         
160         return true;
161     }
162 
163     public Node getNode(Point2D point) {
164         return getNode(point, root, 0);
165     }
166     
167     private Node getNode(Point2D point, Node node, int depth){
168         if(node==null) return null;
169         // select direction
170         int dir = depth%2;
171         
172         // sort points according to i-th dimension
173         int res;
174         if(dir==0){
175             // Compare points based on their x-coordinate
176             res = xComparator.compare(point, node.point);
177         }else{
178             // Compare points based on their x-coordinate
179             res = yComparator.compare(point, node.point);
180         }
181         
182         if(res<0)
183             return getNode(point, node.left, depth+1);
184         if(res>0)
185             return getNode(point, node.right, depth+1);
186         
187         return node;
188     }
189 
190     public void add(Point2D point){
191        add(point, root, 0);
192     }
193     
194     private void add(Point2D point, Node node, int depth) {
195         // select direction
196         int dir = depth%2;
197         
198         // sort points according to i-th dimension
199         int res;
200         if(dir==0){
201             // Compare points based on their x-coordinate
202             res = xComparator.compare(point, node.point);
203         }else{
204             // Compare points based on their x-coordinate
205             res = yComparator.compare(point, node.point);
206         }
207         
208         if(res<0){
209             if(node.left==null)
210                 node.left = new Node(point);
211             else
212                 add(point, node.left, depth+1);
213         }
214         if(res>0)
215             if(node.right==null)
216                 node.right = new Node(point);
217             else
218                 add(point, node.right, depth+1);
219     }
220     
221     public Collection<Point2D> rangeSearch(Box2D range) {
222         ArrayList<Point2D> points = new ArrayList<Point2D>();
223         rangeSearch(range, points, root, 0);
224         return points;
225     }
226     
227     /**
228      * range search, by recursively adding points to the collection.
229      */
230     private void rangeSearch(Box2D range, 
231             Collection<Point2D> points, Node node, int depth) {
232         if(node==null)
233             return;
234         
235         // extract the point
236         Point2D point = node.getPoint();
237         double x = point.getX();
238         double y = point.getY();
239         
240         // check if point is in range
241         boolean tx1 = range.getMinX()<x;
242         boolean ty1 = range.getMinY()<y;
243         boolean tx2 = x<=range.getMaxX();
244         boolean ty2 = y<=range.getMaxY();
245         
246         // adds the point if it is present
247         if(tx1 && tx2 && ty1 && ty2)
248             points.add(point);
249         
250         // select direction
251         int dir = depth%2;
252         
253         if(dir==0 ? tx1 : ty1)
254             rangeSearch(range, points, node.left, depth+1);
255         if(dir==0 ? tx2 : ty2)
256             rangeSearch(range, points, node.right, depth+1);
257     }
258     
259     
260     public Point2D nearestNeighbor(Point2D point) {
261         return nearestNeighbor(point, root, root, 0).getPoint();
262     }
263     
264     /**
265      * Return either the same node as candidate, or another node whose point
266      * is closer.
267      */
268     private Node nearestNeighbor(Point2D point, Node candidate, Node node, 
269             int depth) {
270         // Check if the current node is closest that current candidate
271         double distCand = candidate.point.getDistance(point);
272         double dist     = node.point.getDistance(point);
273         if(dist<distCand){
274             candidate = node;
275         }
276         
277         // select direction
278         int dir = depth%2;
279 
280         Node node1, node2;
281         
282         // First try on the canonical side,
283         // the result is the closest node found by depth-firth search
284         Point2D anchor = node.getPoint();
285         StraightLine2D line;
286         if(dir==0){
287             boolean b = point.getX()<anchor.getX();
288             node1 = b ? node.left : node.right;
289             node2 = b ? node.right : node.left;
290            line = StraightLine2D.create(anchor, new Vector2D(0, 1)); 
291         } else {
292             boolean b = point.getY()<anchor.getY();
293             node1 = b ? node.left : node.right;
294             node2 = b ? node.right : node.left;
295             line = StraightLine2D.create(anchor, new Vector2D(1, 0)); 
296         }
297         
298         if(node1!=null) {
299             // Try to find a better candidate
300             candidate = nearestNeighbor(point, candidate, node1, depth+1);
301 
302             // recomputes distance to the (possibly new) candidate
303             distCand = candidate.getPoint().getDistance(point);
304         }
305         
306         // If line is close enough, there can be closer points to the other
307         // side of the line
308         if(line.getDistance(point)<distCand && node2!=null) {
309             candidate = nearestNeighbor(point, candidate, node2, depth+1);
310         }
311         
312         return candidate;
313     }
314     
315     
316     /**
317      * Gives a small example of use.
318      */
319     public static void main(String[] args){
320         int n = 3;
321         ArrayList<Point2D> points = new ArrayList<Point2D>(n);
322         points.add(new Point2D(5, 5));
323         points.add(new Point2D(10, 10));
324         points.add(new Point2D(20, 20));
325         
326         System.out.println("Check KDTree2D");
327         
328         KDTree2D tree = new KDTree2D(points);
329         
330         System.out.println(tree.contains(new Point2D(5, 5)));
331         System.out.println(tree.contains(new Point2D(6, 5)));
332     }
333     
334 }