aboutsummaryrefslogblamecommitdiff
path: root/data/init_valid.py
blob: 14a854c1f0ab796272b9539ea6af81262a81042d (plain) (tree)




























































                                                                                            
#!/usr/bin/env python
# Initialize the valid HDF5 file: empty datasets plus a placeholder 'split' attribute.

import os
import sys

import h5py
import numpy
import theano

import data


# Dataset name -> dtype for every dataset that init_valid creates in the
# HDF5 file.  init_valid also writes one row per entry into the file's
# 'split' attribute.
_fields = {
    'trip_id': 'S19',  # fixed-length 19-byte string
    'call_type': numpy.uint8,
    'origin_call': numpy.uint32,
    'origin_stand': numpy.uint8,
    'taxi_id': numpy.uint16,
    'timestamp': numpy.uint32,
    'day_type': numpy.uint8,
    # numpy.bool_ rather than the numpy.bool alias: the alias was deprecated
    # in NumPy 1.20 and removed in 1.24, whereas numpy.bool_ is available in
    # every release and denotes the same dtype.
    'missing_data': numpy.bool_,
    # data.Polyline comes from the project's data module (not visible here);
    # presumably a variable-length dtype for coordinate sequences -- confirm.
    'latitude': data.Polyline,
    'longitude': data.Polyline,
    # theano.config.floatX is Theano's configured float dtype name.
    'destination_latitude': theano.config.floatX,
    'destination_longitude': theano.config.floatX,
    'travel_time': numpy.uint32,
}


def init_valid(path):
    """Create the HDF5 file at ``path`` with one empty dataset per field.

    The file is opened in mode ``'w'``, so an existing file at ``path`` is
    truncated.  Every entry of ``_fields`` becomes a zero-length 1-D dataset
    with an unlimited maximum length.  A ``split`` attribute is attached to
    the file with one row per field: split name ``'dummy'``, the field name
    as ``source``, ``start`` and ``stop`` both 0, ``available`` False and the
    comment ``'.'``.

    Args:
        path: Location of the HDF5 file to create.
    """
    # 'S<n>' and plain scalar types replace the deprecated 'a' alias and the
    # deprecated (type, 1) shape-1 form; the resulting layout is the same.
    split_dtype = numpy.dtype([
        ('split', 'S64'),
        ('source', 'S21'),
        ('start', numpy.int64),
        ('stop', numpy.int64),
        ('available', numpy.bool_),
        ('comment', 'S1'),
    ])

    # numpy.empty leaves the memory uninitialised, so every column is
    # assigned explicitly below.
    split_array = numpy.empty(len(_fields), dtype=split_dtype)
    split_array['split'] = 'dummy'.encode('utf8')
    # The field name belongs in the 'source' column only.  Assigning it to
    # the whole record (as the previous code did) would broadcast it to
    # every column of that row.
    split_array['source'] = [name.encode('utf8') for name in _fields]
    split_array['start'] = 0
    split_array['stop'] = 0
    split_array['available'] = False
    split_array['comment'] = '.'.encode('utf8')

    # The context manager closes (and thereby flushes) the file even if one
    # of the writes below raises.
    with h5py.File(path, 'w') as h5file:
        for name, dtype in _fields.items():
            h5file.create_dataset(name, (0,), dtype=dtype, maxshape=(None,))
        h5file.attrs['split'] = split_array

if __name__ == '__main__':
    # Usage: <script> [file].  Without an argument the file is created as
    # 'valid.hdf5' inside data.path.
    if len(sys.argv) > 2:
        # sys.stderr.write behaves the same on Python 2 and Python 3; the
        # Python 2 ``print >> sys.stderr, ...`` statement raises TypeError
        # on Python 3 instead of printing the message.
        sys.stderr.write('Usage: %s [file]\n' % sys.argv[0])
        sys.exit(1)
    if len(sys.argv) == 2:
        target = sys.argv[1]
    else:
        target = os.path.join(data.path, 'valid.hdf5')
    init_valid(target)