diff options
-rwxr-xr-x[-rw-r--r--] | data/make_tvt.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/data/make_tvt.py b/data/make_tvt.py index c878f58..983eb0f 100644..100755 --- a/data/make_tvt.py +++ b/data/make_tvt.py @@ -31,6 +31,9 @@ native_fields = { all_fields = { 'path_len': numpy.int16, 'cluster': numpy.int16, + 'destination_latitude': numpy.float32, + 'destination_longitude': numpy.float32, + 'travel_time': numpy.int32, } all_fields.update(native_fields) @@ -125,6 +128,15 @@ def make_tvt(test_cuts_name, valid_cuts_name, outpath): i = train_i train_i += 1 + trajlen = len(traindata['latitude'][idtraj]) + if trajlen == 0: + hdata['destination_latitude'] = data.train_gps_mean[0] + hdata['destination_longitude'] = data.train_gps_mean[1] + else: + hdata['destination_latitude'] = traindata['latitude'][idtraj][-1] + hdata['destination_longitude'] = traindata['longitude'][idtraj][-1] + hdata['travel_time'] = trajlen + for field in native_fields: val = traindata[field][idtraj] if field in ['latitude', 'longitude']: |