diff options
-rw-r--r-- | data_analysis/maps.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/data_analysis/maps.py b/data_analysis/maps.py index e951f23..d5db182 100644 --- a/data_analysis/maps.py +++ b/data_analysis/maps.py @@ -10,13 +10,14 @@ from data.hdf5 import TaxiDataset, TaxiStream def compute_number_coordinates(): - stream = TaxiDataset('train').get_example_stream() + dataset = TaxiDataset('train') + stream = DataStream(dataset, iteration_scheme=ConstantScheme(1, dataset.num_examples)) train_it = stream.get_epoch_iterator() # Count the number of coordinates n_coordinates = 0 for ride in train_it: - n_coordinates += len(ride[-1]) + n_coordinates += len(ride[2]) print n_coordinates return n_coordinates @@ -51,7 +52,7 @@ def draw_map(coordinates, xrg, yrg): hist, xx, yy = np.histogram2d(coordinates[:, 0], coordinates[:, 1], bins=2000, range=[xrg, yrg]) plt.imshow(np.log(hist)) - plt.savefig(data.DATA_PATH + "/analysis/xyhmap2.png") + plt.savefig(data.path + "/analysis/xyhmap2.png") if __name__ == "__main__": |