aboutsummaryrefslogtreecommitdiff
path: root/data/__init__.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-05-06 11:54:52 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-05-06 11:55:26 -0400
commit0b4b65cb3d88ac4818e71ccef0bded3ddee0683c (patch)
tree75c0942dd20e73a6c4a67ab207e916f83c7d4bae /data/__init__.py
parent35b4503ddd148b0c937468891dd0a7e9ff1c79f4 (diff)
downloadtaxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.tar.gz
taxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.zip
Fix floatX!=float32 in hdf5 creation
Diffstat (limited to 'data/__init__.py')
-rw-r--r--data/__init__.py7
1 files changed, 3 insertions, 4 deletions
diff --git a/data/__init__.py b/data/__init__.py
index 1278e0b..2121033 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -2,12 +2,11 @@ import os
import h5py
import numpy
-import theano
path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle')
-Polyline = h5py.special_dtype(vlen=theano.config.floatX)
+Polyline = h5py.special_dtype(vlen=numpy.float32)
# `wc -l test.csv` - 1 # Minus 1 to ignore the header
test_size = 320
@@ -27,5 +26,5 @@ origin_call_size = 57125 # include 0 ("no origin_call")
# As printed by csv_to_hdf5.py
origin_call_train_size = 57106
-train_gps_mean = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
-train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
+train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32)
+train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32))