diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 16:43:48 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-05-04 16:58:31 -0400 |
commit | 80d3ea67a845484d119cb88f0a0412f981ab344c (patch) | |
tree | 37b8130b6d761bcda48c8c0f74114498b85dad97 /data_analysis/cluster_arrival.py | |
parent | f9a31bd246e3c4736d3f532b566b7437eba6b4de (diff) | |
download | taxi-80d3ea67a845484d119cb88f0a0412f981ab344c.tar.gz taxi-80d3ea67a845484d119cb88f0a0412f981ab344c.zip |
New data analysis tool: clustering of arrival points.
Diffstat (limited to 'data_analysis/cluster_arrival.py')
-rw-r--r-- | data_analysis/cluster_arrival.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/data_analysis/cluster_arrival.py b/data_analysis/cluster_arrival.py
new file mode 100644
index 0000000..fd4ea04
--- /dev/null
+++ b/data_analysis/cluster_arrival.py
@@ -0,0 +1,39 @@
+"""Cluster the arrival points with mean shift and save the cluster centers.
+
+Reads the pickled point list from ``arrivals.pkl`` and writes the resulting
+cluster centers to ``arrival-cluters.pkl`` (file name kept as-is in case
+other scripts already read it -- TODO confirm).
+"""
+import matplotlib.pyplot as plt
+import numpy
+import cPickle
+import scipy.misc
+
+from sklearn.cluster import MeanShift, estimate_bandwidth
+from sklearn.datasets.samples_generator import make_blobs
+from itertools import cycle
+
+print "Reading arrival point list"
+# Pickles are binary data: open in binary mode so the bytes are not
+# altered by newline translation on platforms that distinguish text mode.
+with open("arrivals.pkl", "rb") as f:
+    pts = cPickle.load(f)
+
+print "Doing clustering"
+# The estimated bandwidth is only printed for reference; the clustering
+# itself runs with the hand-picked value assigned below.
+bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
+print bw
+bw = 0.001
+
+ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
+ms.fit(pts)
+cluster_centers = ms.cluster_centers_
+
+print "Clusters shape: ", cluster_centers.shape
+
+# HIGHEST_PROTOCOL is a binary pickle format, so the file must be opened
+# in binary mode for writing.
+with open("arrival-cluters.pkl", "wb") as f:
+    cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)
+