aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--data_analysis/maps.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/data_analysis/maps.py b/data_analysis/maps.py
index e951f23..d5db182 100644
--- a/data_analysis/maps.py
+++ b/data_analysis/maps.py
@@ -10,13 +10,14 @@ from data.hdf5 import TaxiDataset, TaxiStream
def compute_number_coordinates():
- stream = TaxiDataset('train').get_example_stream()
+ dataset = TaxiDataset('train')
+ stream = DataStream(dataset, iteration_scheme=ConstantScheme(1, dataset.num_examples))
train_it = stream.get_epoch_iterator()
# Count the number of coordinates
n_coordinates = 0
for ride in train_it:
- n_coordinates += len(ride[-1])
+ n_coordinates += len(ride[2])
print n_coordinates
return n_coordinates
@@ -51,7 +52,7 @@ def draw_map(coordinates, xrg, yrg):
hist, xx, yy = np.histogram2d(coordinates[:, 0], coordinates[:, 1], bins=2000, range=[xrg, yrg])
plt.imshow(np.log(hist))
- plt.savefig(data.DATA_PATH + "/analysis/xyhmap2.png")
+ plt.savefig(data.path + "/analysis/xyhmap2.png")
if __name__ == "__main__":