diff options
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 12 |
1 file changed, 9 insertions, 3 deletions
@@ -6,6 +6,7 @@ from enum import Enum from fuel.datasets import Dataset from fuel.streams import DataStream from fuel.iterator import DataIterator +import theano if socket.gethostname() == "adeb.laptop": DATA_PATH = "/Users/adeb/data/taxi" @@ -14,8 +15,8 @@ else: client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))} -porto_center = numpy.array([[ -8.61612, 41.1573]], dtype='float32') -data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype='float32')) +porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX) +data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX)) class CallType(Enum): CENTRAL = 0 @@ -143,8 +144,13 @@ taxi_columns_valid = taxi_columns + [ train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] valid_files=["%s/split/valid.csv" % (DATA_PATH,)] + train_data=TaxiData(train_files, taxi_columns) -valid_data=TaxiData(valid_files, taxi_columns_valid) + +valid_data = TaxiData(valid_files, taxi_columns_valid) + +# for the moment - will be changed later +test_data = valid_data def train_it(): return DataIterator(DataStream(train_data)) |