diff options
-rw-r--r-- | data/cuts/tvt_test.py | 9 | ||||
-rw-r--r-- | data/cuts/tvt_valid.py | 9 | ||||
-rw-r--r-- | data/make_tvt.py | 188 |
3 files changed, 206 insertions, 0 deletions
diff --git a/data/cuts/tvt_test.py b/data/cuts/tvt_test.py new file mode 100644 index 0000000..cd97a2b --- /dev/null +++ b/data/cuts/tvt_test.py @@ -0,0 +1,9 @@ +import random + +begin = 1372636853 +end = 1404172787 + +random.seed(42) +cuts = [] +for i in range(500): + cuts.append(random.randrange(begin, end)) diff --git a/data/cuts/tvt_valid.py b/data/cuts/tvt_valid.py new file mode 100644 index 0000000..c5bb828 --- /dev/null +++ b/data/cuts/tvt_valid.py @@ -0,0 +1,9 @@ +import random + +begin = 1372636853 +end = 1404172787 + +random.seed(1337) +cuts = [] +for i in range(500): + cuts.append(random.randrange(begin, end)) diff --git a/data/make_tvt.py b/data/make_tvt.py new file mode 100644 index 0000000..c878f58 --- /dev/null +++ b/data/make_tvt.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python2 +# Separate the training set into a Training Valid and Test set + +import os +import sys +import importlib +import cPickle + +import h5py +import numpy +import theano + +import data +from data.hdf5 import TaxiDataset +from error import hdist + + +native_fields = { + 'trip_id': 'S19', + 'call_type': numpy.int8, + 'origin_call': numpy.int32, + 'origin_stand': numpy.int8, + 'taxi_id': numpy.int16, + 'timestamp': numpy.int32, + 'day_type': numpy.int8, + 'missing_data': numpy.bool, + 'latitude': data.Polyline, + 'longitude': data.Polyline, +} + +all_fields = { + 'path_len': numpy.int16, + 'cluster': numpy.int16, +} + +all_fields.update(native_fields) + +def cut_me_baby(train, cuts, excl={}): + dset = {} + cuts.sort() + cut_id = 0 + for i in xrange(data.train_size): + if i%10000==0 and i!=0: + print >> sys.stderr, 'cut: {:d} done'.format(i) + if i in excl: + continue + time = train['timestamp'][i] + latitude = train['latitude'][i] + longitude = train['longitude'][i] + + if len(latitude) == 0: + continue + + end_time = time + 15 * (len(latitude) - 1) + + while cuts[cut_id] < time: + if cut_id >= len(cuts)-1: + return dset + cut_id += 1 + + if end_time < cuts[cut_id]: + continue + else: + dset[i] = (cuts[cut_id] - time) / 15 + 1 + + return dset + +def make_tvt(test_cuts_name, valid_cuts_name, outpath): + trainset = TaxiDataset('train') + traindata = trainset.get_data(None, slice(0, trainset.num_examples)) + idsort = traindata[trainset.sources.index('timestamp')].argsort() + + traindata = dict(zip(trainset.sources, (t[idsort] for t in traindata))) + + print >> sys.stderr, 'test cut begin' + test_cuts = importlib.import_module('.%s' % test_cuts_name, 'data.cuts').cuts + test = cut_me_baby(traindata, test_cuts) + + print >> sys.stderr, 'valid cut begin' + valid_cuts = importlib.import_module('.%s' % valid_cuts_name, 'data.cuts').cuts + valid = cut_me_baby(traindata, valid_cuts, test) + + test_size = len(test) + valid_size = len(valid) + train_size = data.train_size - test_size - valid_size + + print ' set | size | ratio' + print ' ----- | ------- | -----' + print ' train | {:>7d} | {:>5.3f}'.format(train_size, float(train_size)/data.train_size) + print ' valid | {:>7d} | {:>5.3f}'.format(valid_size, float(valid_size)/data.train_size) + print ' test | {:>7d} | {:>5.3f}'.format(test_size , float(test_size )/data.train_size) + + with open(os.path.join(data.path, 'arrival-clusters.pkl'), 'r') as f: + clusters = cPickle.load(f) + + print >> sys.stderr, 'compiling cluster assignment function' + latitude = theano.tensor.scalar('latitude') + longitude = theano.tensor.scalar('longitude') + coords = theano.tensor.stack(latitude, longitude).dimshuffle('x', 0) + parent = theano.tensor.argmin(hdist(clusters, coords)) + cluster = theano.function([latitude, longitude], parent) + + print >> sys.stderr, 'preparing hdf5 data' + hdata = {k: numpy.empty(shape=(data.train_size,), dtype=v) for k, v in all_fields.iteritems()} + + train_i = 0 + valid_i = train_size + test_i = train_size + valid_size + + print >> sys.stderr, 'write: begin' + for idtraj in xrange(data.train_size): + if idtraj%10000==0 and idtraj!=0: + print >> sys.stderr, 'write: {:d} done'.format(idtraj) + in_test = idtraj in test + in_valid = not in_test and idtraj in valid + in_train = not in_test and not in_valid + + if idtraj in test: + i = test_i + test_i += 1 + elif idtraj in valid: + i = valid_i + valid_i += 1 + else: + i = train_i + train_i += 1 + + for field in native_fields: + val = traindata[field][idtraj] + if field in ['latitude', 'longitude']: + if in_test: + val = val[:test[idtraj]] + elif in_valid: + val = val[:valid[idtraj]] + hdata[field][i] = val + + plen = len(hdata['latitude'][i]) + hdata['path_len'][i] = plen + hdata['cluster'][i] = -1 if plen==0 else cluster(hdata['latitude'][i][0], hdata['longitude'][i][0]) + + print >> sys.stderr, 'write: end' + + print >> sys.stderr, 'preparing split array' + + split_array = numpy.empty(len(all_fields)*3, dtype=numpy.dtype([ + ('split', 'a', 64), + ('source', 'a', 21), + ('start', numpy.int64, 1), + ('stop', numpy.int64, 1), + ('indices', h5py.special_dtype(ref=h5py.Reference)), + ('available', numpy.bool, 1), + ('comment', 'a', 1)])) + + flen = len(all_fields) + for i, field in enumerate(all_fields): + split_array[i]['split'] = 'train'.encode('utf8') + split_array[i+flen]['split'] = 'valid'.encode('utf8') + split_array[i+2*flen]['split'] = 'test'.encode('utf8') + split_array[i]['start'] = 0 + split_array[i]['stop'] = train_size + split_array[i+flen]['start'] = train_size + split_array[i+flen]['stop'] = train_size + valid_size + split_array[i+2*flen]['start'] = train_size + valid_size + split_array[i+2*flen]['stop'] = train_size + valid_size + test_size + + for d in [0, flen, 2*flen]: + split_array[i+d]['source'] = field.encode('utf8') + + split_array[:]['indices'] = None + split_array[:]['available'] = True + split_array[:]['comment'] = '.'.encode('utf8') + + print >> sys.stderr, 'writing hdf5 file' + file = h5py.File(outpath, 'w') + for k in all_fields.keys(): + file.create_dataset(k, data=hdata[k], maxshape=(data.train_size,)) + + file.attrs['split'] = split_array + + file.flush() + file.close() + +if __name__ == '__main__': + if len(sys.argv) < 3 or len(sys.argv) > 4: + print >> sys.stderr, 'Usage: %s test_cutfile valid_cutfile [outfile]' % sys.argv[0] + sys.exit(1) + outpath = os.path.join(data.path, 'tvt.hdf5') if len(sys.argv) < 4 else sys.argv[3] + make_tvt(sys.argv[1], sys.argv[2], outpath) |