aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 18:34:51 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 18:34:51 -0400
commitfe608831c62c7dba60a3bf57433d97b999e567c8 (patch)
tree1f54fa71b4b61e78868344d36fb7715bb7109649
parente7aba08e6b209ac7f091eb9f08b49a2c90b070ed (diff)
downloadtaxi-fe608831c62c7dba60a3bf57433d97b999e567c8.tar.gz
taxi-fe608831c62c7dba60a3bf57433d97b999e567c8.zip
Fix tvt hdf5
-rwxr-xr-x[-rw-r--r--]data/make_tvt.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/data/make_tvt.py b/data/make_tvt.py
index c878f58..983eb0f 100644..100755
--- a/data/make_tvt.py
+++ b/data/make_tvt.py
@@ -31,6 +31,9 @@ native_fields = {
all_fields = {
'path_len': numpy.int16,
'cluster': numpy.int16,
+ 'destination_latitude': numpy.float32,
+ 'destination_longitude': numpy.float32,
+ 'travel_time': numpy.int32,
}
all_fields.update(native_fields)
@@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath):
i = train_i
train_i += 1
+ trajlen = len(traindata['latitude'][idtraj])
+ if trajlen == 0:
+ hdata['destination_latitude'] = data.train_gps_mean[0]
+ hdata['destination_longitude'] = data.train_gps_mean[1]
+ else:
+ hdata['destination_latitude'] = traindata['latitude'][idtraj][-1]
+ hdata['destination_longitude'] = traindata['longitude'][idtraj][-1]
+ hdata['travel_time'] = trajlen
+
for field in native_fields:
val = traindata[field][idtraj]
if field in ['latitude', 'longitude']: