aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-05-06 11:54:52 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-05-06 11:55:26 -0400
commit0b4b65cb3d88ac4818e71ccef0bded3ddee0683c (patch)
tree75c0942dd20e73a6c4a67ab207e916f83c7d4bae
parent35b4503ddd148b0c937468891dd0a7e9ff1c79f4 (diff)
downloadtaxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.tar.gz
taxi-0b4b65cb3d88ac4818e71ccef0bded3ddee0683c.zip
Fix floatX!=float32 in hdf5 creation
-rw-r--r--data/__init__.py7
-rwxr-xr-xdata/csv_to_hdf5.py9
-rwxr-xr-xdata/init_valid.py5
-rw-r--r--data/rfc4180.py (renamed from data/csv.py)0
4 files changed, 9 insertions, 12 deletions
diff --git a/data/__init__.py b/data/__init__.py
index 1278e0b..2121033 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -2,12 +2,11 @@ import os
import h5py
import numpy
-import theano
path = os.environ.get('TAXI_PATH', '/data/lisatmp3/auvolat/taxikaggle')
-Polyline = h5py.special_dtype(vlen=theano.config.floatX)
+Polyline = h5py.special_dtype(vlen=numpy.float32)
# `wc -l test.csv` - 1 # Minus 1 to ignore the header
test_size = 320
@@ -27,5 +26,5 @@ origin_call_size = 57125 # include 0 ("no origin_call")
# As printed by csv_to_hdf5.py
origin_call_train_size = 57106
-train_gps_mean = numpy.array([41.1573, -8.61612], dtype=theano.config.floatX)
-train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=theano.config.floatX))
+train_gps_mean = numpy.array([41.1573, -8.61612], dtype=numpy.float32)
+train_gps_std = numpy.sqrt(numpy.array([0.00549598, 0.00333233], dtype=numpy.float32))
diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py
index 17217f3..97cf428 100755
--- a/data/csv_to_hdf5.py
+++ b/data/csv_to_hdf5.py
@@ -7,7 +7,6 @@ import sys
import h5py
import numpy
-import theano
from fuel.converters.base import fill_hdf5_file
import data
@@ -32,8 +31,8 @@ def get_unique_origin_call(val):
def read_stands(input_directory, h5file):
stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24))
- stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX)
- stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX)
+ stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32)
+ stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32)
stands_name[0] = 'None'
stands_latitude[0] = stands_longitude[0] = 0
with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f:
@@ -77,8 +76,8 @@ def read_taxis(input_directory, h5file, dataset):
day_type[id] = ord(line[6][0]) - ord('A')
missing_data[id] = line[7][0] == 'T'
polyline = ast.literal_eval(line[8])
- latitude[id] = numpy.array([point[1] for point in polyline], dtype=theano.config.floatX)
- longitude[id] = numpy.array([point[0] for point in polyline], dtype=theano.config.floatX)
+ latitude[id] = numpy.array([point[1] for point in polyline], dtype=numpy.float32)
+ longitude[id] = numpy.array([point[0] for point in polyline], dtype=numpy.float32)
id+=1
splits = ()
print >> sys.stderr, 'read %s: writing' % dataset
diff --git a/data/init_valid.py b/data/init_valid.py
index 14a854c..eed0059 100755
--- a/data/init_valid.py
+++ b/data/init_valid.py
@@ -6,7 +6,6 @@ import sys
import h5py
import numpy
-import theano
import data
@@ -22,8 +21,8 @@ _fields = {
'missing_data': numpy.bool,
'latitude': data.Polyline,
'longitude': data.Polyline,
- 'destination_latitude': theano.config.floatX,
- 'destination_longitude': theano.config.floatX,
+ 'destination_latitude': numpy.float32,
+ 'destination_longitude': numpy.float32,
'travel_time': numpy.uint32,
}
diff --git a/data/csv.py b/data/rfc4180.py
index b6fe5b1..b6fe5b1 100644
--- a/data/csv.py
+++ b/data/rfc4180.py