aboutsummaryrefslogtreecommitdiff
path: root/data/make_valid_cut.py
diff options
context:
space:
mode:
authorAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
committerAdeB <adbrebs@gmail.com>2015-05-05 22:15:29 -0400
commitc29a0d3f22134a8d1f5d557b325f6779c5961546 (patch)
tree6fa431d9b3595b5d2d11089920aa07ea43172d90 /data/make_valid_cut.py
parentf4d3ee6449217535bdbe19ac9c5fdd825d71b0d3 (diff)
parent1f2ff96e6480a62089fcac35154a956c218ed678 (diff)
downloadtaxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.tar.gz
taxi-c29a0d3f22134a8d1f5d557b325f6779c5961546.zip
Merge branch 'master' of github.com:adbrebs/taxi
Diffstat (limited to 'data/make_valid_cut.py')
-rwxr-xr-xdata/make_valid_cut.py72
1 files changed, 72 insertions, 0 deletions
diff --git a/data/make_valid_cut.py b/data/make_valid_cut.py
new file mode 100755
index 0000000..d5be083
--- /dev/null
+++ b/data/make_valid_cut.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+# Make a valid dataset by cutting the training set at specified timestamps
+
+import os
+import sys
+import importlib
+
+import h5py
+import numpy
+
+import data
+from data.hdf5 import taxi_it
+
+
+_fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time']
+
+def make_valid(cutfile, outpath):
+ cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts
+
+ valid = []
+
+ for line in taxi_it('train'):
+ time = line['timestamp']
+ latitude = line['latitude']
+ longitude = line['longitude']
+
+ if len(latitude) == 0:
+ continue
+
+ for ts in cuts:
+ if time <= ts and time + 15 * (len(latitude) - 1) >= ts:
+ # keep it
+ n = (ts - time) / 15 + 1
+ line.update({
+ 'latitude': latitude[:n],
+ 'longitude': longitude[:n],
+ 'destination_latitude': latitude[-1],
+ 'destination_longitude': longitude[-1],
+ 'travel_time': 15 * (len(latitude)-1)
+ })
+ valid.append(line)
+
+ file = h5py.File(outpath, 'a')
+ clen = file['trip_id'].shape[0]
+ alen = len(valid)
+ for field in _fields:
+ dset = file[field]
+ dset.resize((clen + alen,))
+ for i in xrange(alen):
+ dset[clen + i] = valid[i][field]
+
+ splits = file.attrs['split']
+ slen = splits.shape[0]
+ splits = numpy.resize(splits, (slen+len(_fields),))
+ for (i, field) in enumerate(_fields):
+ splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8')
+ splits[slen+i]['source'] = field.encode('utf8')
+ splits[slen+i]['start'] = clen
+ splits[slen+i]['stop'] = alen
+ splits[slen+i]['available'] = True
+ splits[slen+i]['comment'] = '.'
+ file.attrs['split'] = splits
+
+ file.flush()
+ file.close()
+
+if __name__ == '__main__':
+ if len(sys.argv) < 2 or len(sys.argv) > 3:
+ print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0]
+ sys.exit(1)
+ outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2]
+ make_valid(sys.argv[1], outpath)