From fe608831c62c7dba60a3bf57433d97b999e567c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Thu, 23 Jul 2015 18:34:51 -0400 Subject: Fix tvt hdf5 --- data/make_tvt.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) mode change 100644 => 100755 data/make_tvt.py (limited to 'data/make_tvt.py') diff --git a/data/make_tvt.py b/data/make_tvt.py old mode 100644 new mode 100755 index c878f58..983eb0f --- a/data/make_tvt.py +++ b/data/make_tvt.py @@ -31,6 +31,9 @@ native_fields = { all_fields = { 'path_len': numpy.int16, 'cluster': numpy.int16, + 'destination_latitude': numpy.float32, + 'destination_longitude': numpy.float32, + 'travel_time': numpy.int32, } all_fields.update(native_fields) @@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): i = train_i train_i += 1 + trajlen = len(traindata['latitude'][idtraj]) + if trajlen == 0: + hdata['destination_latitude'] = data.train_gps_mean[0] + hdata['destination_longitude'] = data.train_gps_mean[1] + else: + hdata['destination_latitude'] = traindata['latitude'][idtraj][-1] + hdata['destination_longitude'] = traindata['longitude'][idtraj][-1] + hdata['travel_time'] = trajlen + for field in native_fields: val = traindata[field][idtraj] if field in ['latitude', 'longitude']: -- cgit v1.2.3