aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 19:24:00 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 19:24:00 -0400
commit206e7c17aca14d0c48cb59f4bb8b3607279aeaba (patch)
treede3dfc9cbee2b77e99f3cb988656a10a933006a1
parent1795dfe742bcb75085a909413b723b64a8eeb4fc (diff)
downloadtaxi-206e7c17aca14d0c48cb59f4bb8b3607279aeaba.tar.gz
taxi-206e7c17aca14d0c48cb59f4bb8b3607279aeaba.zip
Remove useless call_origin and bug fix in tvt
-rwxr-xr-xdata/make_tvt.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/data/make_tvt.py b/data/make_tvt.py
index 983eb0f..a3c1de6 100755
--- a/data/make_tvt.py
+++ b/data/make_tvt.py
@@ -103,6 +103,8 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
parent = theano.tensor.argmin(hdist(clusters, coords))
cluster = theano.function([latitude, longitude], parent)
+ train_clients = set()
+
print >> sys.stderr, 'preparing hdf5 data'
hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()}
@@ -125,17 +127,18 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
i = valid_i
valid_i += 1
else:
+ train_clients.add(traindata['origin_call'][idtraj])
i = train_i
train_i += 1
trajlen = len(traindata['latitude'][idtraj])
if trajlen == 0:
- hdata['destination_latitude'] = data.train_gps_mean[0]
- hdata['destination_longitude'] = data.train_gps_mean[1]
+ hdata['destination_latitude'][i] = data.train_gps_mean[0]
+ hdata['destination_longitude'][i] = data.train_gps_mean[1]
else:
- hdata['destination_latitude'] = traindata['latitude'][idtraj][-1]
- hdata['destination_longitude'] = traindata['longitude'][idtraj][-1]
- hdata['travel_time'] = trajlen
+ hdata['destination_latitude'][i] = traindata['latitude'][idtraj][-1]
+ hdata['destination_longitude'][i] = traindata['longitude'][idtraj][-1]
+ hdata['travel_time'][i] = trajlen
for field in native_fields:
val = traindata[field][idtraj]
@@ -152,6 +155,11 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
print >> sys.stderr, 'write: end'
+ print >> sys.stderr, 'removing useless origin_call'
+ for i in xrange(train_size, data.train_size):
+ if hdata['origin_call'][i] not in train_clients:
+ hdata['origin_call'][i] = 0
+
print >> sys.stderr, 'preparing split array'
split_array = numpy.empty(len(all_fields)*3, dtype=numpy.dtype([