diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 19:24:00 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 19:24:00 -0400 |
commit | 206e7c17aca14d0c48cb59f4bb8b3607279aeaba (patch) | |
tree | de3dfc9cbee2b77e99f3cb988656a10a933006a1 | |
parent | 1795dfe742bcb75085a909413b723b64a8eeb4fc (diff) | |
download | taxi-206e7c17aca14d0c48cb59f4bb8b3607279aeaba.tar.gz taxi-206e7c17aca14d0c48cb59f4bb8b3607279aeaba.zip |
Remove useless call_origin and bug fix in tvt
-rwxr-xr-x | data/make_tvt.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/data/make_tvt.py b/data/make_tvt.py index 983eb0f..a3c1de6 100755 --- a/data/make_tvt.py +++ b/data/make_tvt.py @@ -103,6 +103,8 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): parent = theano.tensor.argmin(hdist(clusters, coords)) cluster = theano.function([latitude, longitude], parent) + train_clients = set() + print >> sys.stderr, 'preparing hdf5 data' hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()} @@ -125,17 +127,18 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): i = valid_i valid_i += 1 else: + train_clients.add(traindata['origin_call'][idtraj]) i = train_i train_i += 1 trajlen = len(traindata['latitude'][idtraj]) if trajlen == 0: - hdata['destination_latitude'] = data.train_gps_mean[0] - hdata['destination_longitude'] = data.train_gps_mean[1] + hdata['destination_latitude'][i] = data.train_gps_mean[0] + hdata['destination_longitude'][i] = data.train_gps_mean[1] else: - hdata['destination_latitude'] = traindata['latitude'][idtraj][-1] - hdata['destination_longitude'] = traindata['longitude'][idtraj][-1] - hdata['travel_time'] = trajlen + hdata['destination_latitude'][i] = traindata['latitude'][idtraj][-1] + hdata['destination_longitude'][i] = traindata['longitude'][idtraj][-1] + hdata['travel_time'][i] = trajlen for field in native_fields: val = traindata[field][idtraj] @@ -152,6 +155,11 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): print >> sys.stderr, 'write: end' + print >> sys.stderr, 'removing useless origin_call' + for i in xrange(train_size, data.train_size): + if hdata['origin_call'][i] not in train_clients: + hdata['origin_call'][i] = 0 + print >> sys.stderr, 'preparing split array' split_array = numpy.empty(len(all_fields)*3, dtype=numpy.dtype([ |