diff options
Diffstat (limited to 'data/make_valid_cut.py')
-rwxr-xr-x | data/make_valid_cut.py | 72 |
1 files changed, 72 insertions, 0 deletions
diff --git a/data/make_valid_cut.py b/data/make_valid_cut.py new file mode 100755 index 0000000..d5be083 --- /dev/null +++ b/data/make_valid_cut.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# Make a valid dataset by cutting the training set at specified timestamps + +import os +import sys +import importlib + +import h5py +import numpy + +import data +from data.hdf5 import taxi_it + + +_fields = ['trip_id', 'call_type', 'origin_call', 'origin_stand', 'taxi_id', 'timestamp', 'day_type', 'missing_data', 'latitude', 'longitude', 'destination_latitude', 'destination_longitude', 'travel_time'] + +def make_valid(cutfile, outpath): + cuts = importlib.import_module('.%s' % cutfile, 'data.cuts').cuts + + valid = [] + + for line in taxi_it('train'): + time = line['timestamp'] + latitude = line['latitude'] + longitude = line['longitude'] + + if len(latitude) == 0: + continue + + for ts in cuts: + if time <= ts and time + 15 * (len(latitude) - 1) >= ts: + # keep it + n = (ts - time) / 15 + 1 + line.update({ + 'latitude': latitude[:n], + 'longitude': longitude[:n], + 'destination_latitude': latitude[-1], + 'destination_longitude': longitude[-1], + 'travel_time': 15 * (len(latitude)-1) + }) + valid.append(line) + + file = h5py.File(outpath, 'a') + clen = file['trip_id'].shape[0] + alen = len(valid) + for field in _fields: + dset = file[field] + dset.resize((clen + alen,)) + for i in xrange(alen): + dset[clen + i] = valid[i][field] + + splits = file.attrs['split'] + slen = splits.shape[0] + splits = numpy.resize(splits, (slen+len(_fields),)) + for (i, field) in enumerate(_fields): + splits[slen+i]['split'] = ('cuts/%s' % cutfile).encode('utf8') + splits[slen+i]['source'] = field.encode('utf8') + splits[slen+i]['start'] = clen + splits[slen+i]['stop'] = alen + splits[slen+i]['available'] = True + splits[slen+i]['comment'] = '.' + file.attrs['split'] = splits + + file.flush() + file.close() + +if __name__ == '__main__': + if len(sys.argv) < 2 or len(sys.argv) > 3: + print >> sys.stderr, 'Usage: %s cutfile [outfile]' % sys.argv[0] + sys.exit(1) + outpath = os.path.join(data.path, 'valid.hdf5') if len(sys.argv) < 3 else sys.argv[2] + make_valid(sys.argv[1], outpath) |