aboutsummaryrefslogtreecommitdiff
path: root/data_analysis/cluster_arrival.py
blob: 5e990cdb7877773e02ae6d880d1ae3c3b90485cb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python
import numpy
import cPickle
import scipy.misc
import os

from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
from itertools import cycle

import data
from data.hdf5 import taxi_it
from data.transformers import add_destination

print "Generating arrival point list"
dests = []
for v in taxi_it("train"):
    if len(v['latitude']) == 0: continue
    dests.append([v['latitude'][-1], v['longitude'][-1]])
pts = numpy.array(dests)

with open(os.path.join(data.path, "arrivals.pkl"), "w") as f:
    cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL)

print "Doing clustering"
bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
print bw
bw = 0.001 # (

ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
ms.fit(pts)
cluster_centers = ms.cluster_centers_

print "Clusters shape: ", cluster_centers.shape

with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f:
    cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)