1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
#!/usr/bin/env python
import numpy
import cPickle
import scipy.misc
import os
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs
from itertools import cycle
import data
from data.hdf5 import taxi_it
from data.transformers import add_destination
print "Generating arrival point list"
dests = []
for v in taxi_it("train"):
if len(v['latitude']) == 0: continue
dests.append([v['latitude'][-1], v['longitude'][-1]])
pts = numpy.array(dests)
with open(os.path.join(data.path, "arrivals.pkl"), "w") as f:
cPickle.dump(pts, f, protocol=cPickle.HIGHEST_PROTOCOL)
print "Doing clustering"
bw = estimate_bandwidth(pts, quantile=.1, n_samples=1000)
print bw
bw = 0.001 # (
ms = MeanShift(bandwidth=bw, bin_seeding=True, min_bin_freq=5)
ms.fit(pts)
cluster_centers = ms.cluster_centers_
print "Clusters shape: ", cluster_centers.shape
with open(os.path.join(data.path, "arrival-clusters.pkl"), "w") as f:
cPickle.dump(cluster_centers, f, protocol=cPickle.HIGHEST_PROTOCOL)
|