From 0b4b65cb3d88ac4818e71ccef0bded3ddee0683c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89tienne=20Simon?= Date: Wed, 6 May 2015 11:54:52 -0400 Subject: Fix floatX!=float32 in hdf5 creation --- data/csv_to_hdf5.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'data/csv_to_hdf5.py') diff --git a/data/csv_to_hdf5.py b/data/csv_to_hdf5.py index 17217f3..97cf428 100755 --- a/data/csv_to_hdf5.py +++ b/data/csv_to_hdf5.py @@ -7,7 +7,6 @@ import sys import h5py import numpy -import theano from fuel.converters.base import fill_hdf5_file import data @@ -32,8 +31,8 @@ def get_unique_origin_call(val): def read_stands(input_directory, h5file): stands_name = numpy.empty(shape=(data.stands_size,), dtype=('a', 24)) - stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX) - stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=theano.config.floatX) + stands_latitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32) + stands_longitude = numpy.empty(shape=(data.stands_size,), dtype=numpy.float32) stands_name[0] = 'None' stands_latitude[0] = stands_longitude[0] = 0 with open(os.path.join(input_directory, 'metaData_taxistandsID_name_GPSlocation.csv'), 'r') as f: @@ -77,8 +76,8 @@ def read_taxis(input_directory, h5file, dataset): day_type[id] = ord(line[6][0]) - ord('A') missing_data[id] = line[7][0] == 'T' polyline = ast.literal_eval(line[8]) - latitude[id] = numpy.array([point[1] for point in polyline], dtype=theano.config.floatX) - longitude[id] = numpy.array([point[0] for point in polyline], dtype=theano.config.floatX) + latitude[id] = numpy.array([point[1] for point in polyline], dtype=numpy.float32) + longitude[id] = numpy.array([point[0] for point in polyline], dtype=numpy.float32) id+=1 splits = () print >> sys.stderr, 'read %s: writing' % dataset -- cgit v1.2.3