diff options
author | Étienne Simon <esimon@esimon.eu> | 2015-05-06 11:54:52 -0400 |
---|---|---|
committer | Étienne Simon <esimon@esimon.eu> | 2015-05-06 11:55:26 -0400 |
commit | 0b4b65cb3d88ac4818e71ccef0bded3ddee0683c (patch) | |
tree | 75c0942dd20e73a6c4a67ab207e916f83c7d4bae /data/__init__.py | |
parent | 35b4503ddd148b0c937468891dd0a7e9ff1c79f4 (diff) | |
download | taxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.tar.gz taxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.zip |
Fix floatX!=float32 in hdf5 creation
Diffstat (limited to 'data/__init__.py')
-rw-r--r-- | data/__init__.py | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/data/__init__.py b/data/__init__.py index 1278e0b..2121033 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -2,12 +2,11 @@ import os import h5py import numpy -import theano path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle') -Polyline = h5py.special_dtype(vlen=theano.config.floatX) +Polyline = h5py.special_dtype(vlen=numpy.float32) # `wc -l test.csv` - 1 # Minus 1 to ignore the header test_size = 320 @@ -27,5 +26,5 @@ origin_call_size = 57125 # include 0 ("no origin_call") # As printed by csv_to_hdf5.py origin_call_train_size = 57106 -train_gps_mean = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX) -train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX)) +train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32) +train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32)) |