aboutsummaryrefslogtreecommitdiff
path: root/data_analysis/cluster_arrival.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 16:43:48 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-05-04 16:58:31 -0400
commit80d3ea67a845484d119cb88f0a0412f981ab344c (patch)
tree37b8130b6d761bcda48c8c0f74114498b85dad97 /data_analysis/cluster_arrival.py
parentf9a31bd246e3c4736d3f532b566b7437eba6b4de (diff)
downloadtaxi-80d3ea67a845484d119cb88f0a0412f981ab344c.tar.gz
taxi-80d3ea67a845484d119cb88f0a0412f981ab344c.zip
New data analysis tool: clustering of arrival points.
Diffstat (limited to 'data_analysis/cluster_arrival.py')
-rw-r--r--data_analysis/cluster_arrival.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/data_analysis/cluster_arrival.py b/data_analysis/cluster_arrival.py
new file mode 100644
index 0000000..fd4ea04
--- /dev/null
+++ b/data_analysis/cluster_arrival.py
@@ -0,0 +1,27 @@
+import matplotlib.pyplot as plt
+import numpy
+import cPickle
+import scipy.misc
+
+from sklearn.cluster import MeanShift, estimate_bandwidth
+from sklearn.datasets.samples_generator import make_blobs
+from itertools import cycle
+
+print "Reading arrival point list"
+with open("arrivals.pkl", "rb") as f:  # pickle streams are bytes: read in binary mode
+ pts = cPickle.load(f)  # NOTE(review): unpickling runs code; only load trusted files
+
+print "Doing clustering"
+bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)  # printed for reference only
+print bw
+bw = 0.001  # hard-coded bandwidth actually used; replaces the estimate above
+
+ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
+ms.fit(pts)
+cluster_centers = ms.cluster_centers_
+
+print "Clusters shape: ", cluster_centers.shape
+
+with open("arrival-cluters.pkl", "wb") as f:  # HIGHEST_PROTOCOL is binary: needs "wb"
+ cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)
+