diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-07-23 18:34:51 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-07-23 18:34:51 -0400 |
commit | fe608831c62c7dba60a3bf57433d97b999e567c8 (patch) | |
tree | 1f54fa71b4b61e78868344d36fb7715bb7109649 /data | |
parent | e7aba08e6b209ac7f091eb9f08b49a2c90b070ed (diff) | |
download | taxi-fe608831c62c7dba60a3bf57433d97b999e567c8.tar.gz taxi-fe608831c62c7dba60a3bf57433d97b999e567c8.zip |
Fix tvt hdf5
Diffstat (limited to 'data')
-rwxr-xr-x[-rw-r--r--] | data/make_tvt.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/data/make_tvt.py b/data/make_tvt.py index c878f58..983eb0f 100644..100755 --- a/data/make_tvt.py +++ b/data/make_tvt.py @@ -31,6 +31,9 @@ native_fields = { all_fields = { 'path_len': numpy.int16, 'cluster': numpy.int16, + 'destination_latitude': numpy.float32, + 'destination_longitude': numpy.float32, + 'travel_time': numpy.int32, } all_fields.update(native_fields) @@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): i = train_i train_i += 1 + trajlen = len(traindata['latitude'][idtraj]) + if trajlen == 0: + hdata['destination_latitude'] = data.train_gps_mean[0] + hdata['destination_longitude'] = data.train_gps_mean[1] + else: + hdata['destination_latitude'] = traindata['latitude'][idtraj][-1] + hdata['destination_longitude'] = traindata['longitude'][idtraj][-1] + hdata['travel_time'] = trajlen + for field in native_fields: val = traindata[field][idtraj] if field in ['latitude', 'longitude']: |