aboutsummaryrefslogtreecommitdiff
path: root/data_analysis
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-10 17:16:26 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-10 17:16:26 -0400
commit788747ec066234c16643595253ebe4d6bfeebe74 (patch)
tree00c89704bfcfa98abadff643ecab5f764a50ba79 /data_analysis
parent27a0e0949c6ca3f7bd18569a23ddd0e1b3e9a64e (diff)
downloadtaxi-788747ec066234c16643595253ebe4d6bfeebe74.tar.gz
taxi-788747ec066234c16643595253ebe4d6bfeebe74.zip
Adjust cluster_arrival.py to make it work again
Diffstat (limited to 'data_analysis')
-rwxr-xr-x[-rw-r--r--]data_analysis/cluster_arrival.py23
1 files changed, 17 insertions, 6 deletions
diff --git a/data_analysis/cluster_arrival.py b/data_analysis/cluster_arrival.py
index fd4ea04..5e990cd 100644..100755
--- a/data_analysis/cluster_arrival.py
+++ b/data_analysis/cluster_arrival.py
@@ -1,20 +1,31 @@
-import matplotlib.pyplot as plt
+#!/usr/bin/env python
import numpy
import cPickle
import scipy.misc
+import os
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
from itertools import cycle
-print "Reading arrival point list"
-with open("arrivals.pkl") as f:
- pts = cPickle.load(f)
+import data
+from data.hdf5 import taxi_it
+from data.transformers import add_destination
+
+print "Generating arrival point list"
+dests = []
+for v in taxi_it("train"):
+ if len(v['latitude']) == 0: continue
+ dests.append([v['latitude'][-1], v['longitude'][-1]])
+pts = numpy.array(dests)
+
+with open(os.path.join(data.path, "arrivals.pkl"), "w") as f:
+ cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL)
print "Doing clustering"
bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
print bw
-bw = 0.001
+bw = 0.001 # (
ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
ms.fit(pts)
@@ -22,6 +33,6 @@ cluster_centers = ms.cluster_centers_
print "Clusters shape: ", cluster_centers.shape
-with open("arrival-cluters.pkl", "w") as f:
+with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f:
cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)