diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-28 15:57:35 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-04-28 15:57:35 -0400 |
commit | d58b121de641c0122652bc3d6096a9d0e1048391 (patch) | |
tree | 294e0e7bcf033c9c1c2bec9efdb5fcf6900c4ec1 /data.py | |
parent | 902a8dcb40b3da9492093edd5bda356240f29eb0 (diff) | |
download | taxi-d58b121de641c0122652bc3d6096a9d0e1048391.tar.gz taxi-d58b121de641c0122652bc3d6096a9d0e1048391.zip |
Add function for applying model
Diffstat (limited to 'data.py')
-rw-r--r-- | data.py | 12 |
1 files changed, 9 insertions, 3 deletions
@@ -6,6 +6,7 @@ from enum import Enum from fuel.datasets import Dataset from fuel.streams import DataStream from fuel.iterator import DataIterator +import theano if socket.gethostname() == "adeb.laptop": DATA_PATH = "/Users/adeb/data/taxi" @@ -14,8 +15,8 @@ else: client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))} -porto_center = numpy.array([[ -8.61612, 41.1573]], dtype='float32') -data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype='float32')) +porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX) +data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX)) class CallType(Enum): CENTRAL = 0 @@ -143,8 +144,13 @@ taxi_columns_valid = taxi_columns + [ train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)] valid_files=["%s/split/valid.csv" % (DATA_PATH,)] + train_data=TaxiData(train_files, taxi_columns) -valid_data=TaxiData(valid_files, taxi_columns_valid) + +valid_data = TaxiData(valid_files, taxi_columns_valid) + +# for the moment - will be changed later +test_data = valid_data def train_it(): return DataIterator(DataStream(train_data)) |