about | summary | refs | log | tree | commit | diff
path: root/data.py
diff options
context:
space:
mode:
Diffstat (limited to 'data.py')
-rw-r--r-- data.py 12
1 files changed, 9 insertions, 3 deletions
diff --git a/data.py b/data.py
index 7708863..f1236a5 100644
--- a/data.py
+++ b/data.py
@@ -6,6 +6,7 @@ from enum import Enum
from fuel.datasets import Dataset
from fuel.streams import DataStream
from fuel.iterator import DataIterator
+import theano
if socket.gethostname() == "adeb.laptop":
DATA_PATH = "/Users/adeb/data/taxi"
@@ -14,8 +15,8 @@ else:
client_ids = {int(x): y+1 for y, x in enumerate(open(DATA_PATH+"/client_ids.txt"))}
-porto_center = numpy.array([[ -8.61612, 41.1573]], dtype='float32')
-data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype='float32'))
+porto_center = numpy.array([[ -8.61612, 41.1573]], dtype=theano.config.floatX)
+data_std = numpy.sqrt(numpy.array([[ 0.00333233, 0.00549598]], dtype=theano.config.floatX))
class CallType(Enum):
CENTRAL = 0
@@ -143,8 +144,13 @@ taxi_columns_valid = taxi_columns + [
train_files=["%s/split/train-%02d.csv" % (DATA_PATH, i) for i in range(100)]
valid_files=["%s/split/valid.csv" % (DATA_PATH,)]
+
train_data=TaxiData(train_files, taxi_columns)
-valid_data=TaxiData(valid_files, taxi_columns_valid)
+
+valid_data = TaxiData(valid_files, taxi_columns_valid)
+
+# TODO: temporary alias -- test_data reuses valid_data for now; will be changed later
+test_data = valid_data
def train_it():
return DataIterator(DataStream(train_data))