# timit.py
#
# Provenance (recovered from the git patch / cgit v1.2.3 export this file
# was extracted from):
#   commit:  072d26e766931007a0f243674f7dfdff5c3104e9
#   From:    Thomas Mesnard
#   Date:    Mon, 28 Dec 2015 20:51:50 +0100
#   Subject: Add plot More TIMIT ; log domain TIMIT: more complexity
#            Nice poster Beautify code (mostly, add comments) Add final stuff.
#   Change:  timit.py created (new file mode 100644, 55 insertions).
#
# NOTE: this module targets Python 2 (it imports cPickle).  The print calls
# below use the single-argument parenthesised form, which produces the same
# output under the Python 2 print statement and also parses under Python 3.

import logging
import random
import numpy

import cPickle

from fuel.datasets import Dataset, IndexableDataset
from fuel.streams import DataStream
from fuel.schemes import (IterationScheme, ConstantScheme,
                          SequentialExampleScheme, ShuffledExampleScheme)
from fuel.transformers import (Batch, Mapping, SortMapping, Unpack, Padding,
                               Transformer)

import sys
import os

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)


class _balanced_batch_helper(object):
    """Sort key: size along the first axis of one source of an example.

    Instances are passed to ``SortMapping`` so that the examples inside a
    large batch get ordered by ``data[key].shape[0]``.
    """

    def __init__(self, key):
        # Position of the source to measure inside the example tuple.
        self.key = key

    def __call__(self, data):
        return data[self.key].shape[0]


def setup_datastream(path, batch_size, sort_batch_count, valid=False):
    """Build the dataset and the padded, length-sorted batch stream.

    Three ``.npy`` files are read from ``path``; their names are prefixed
    with ``valid_`` when ``valid`` is true and with ``train_`` otherwise:
    ``*_x_raw.npy``, ``*_phn.npy`` and ``*_seq_to_phn.npy``.

    The stream is assembled as follows:

    1. examples are drawn in shuffled order;
    2. they are grouped into chunks of ``batch_size * sort_batch_count``
       examples, and each chunk is sorted by the first-axis size of its
       ``'input'`` source (presumably so that sequences of similar length
       share a batch and less padding is needed -- TODO confirm intent);
    3. the sorted chunks are unpacked back into single examples and
       regrouped into batches of ``batch_size``;
    4. ``'input'`` and ``'output'`` are padded, with mask sources added.

    Parameters
    ----------
    path : str
        Directory containing the ``.npy`` files.
    batch_size : int
        Number of examples per final batch.
    sort_batch_count : int
        Number of final batches' worth of examples sorted together.
    valid : bool
        Load the ``valid_*`` files instead of the ``train_*`` ones.

    Returns
    -------
    (ds, stream)
        The ``IndexableDataset`` and the fully transformed data stream.
    """
    prefix = 'valid' if valid else 'train'
    raw_inputs = numpy.load(os.path.join(path, prefix + '_x_raw.npy'))
    phn_table = numpy.load(os.path.join(path, prefix + '_phn.npy'))
    seq_to_phn = numpy.load(os.path.join(path, prefix + '_seq_to_phn.npy'))

    # Each row of seq_to_phn gives a [start, end) range of rows in phn_table;
    # column 2 of those rows is used as the target sequence.
    # NOTE(review): column 2 is presumably the phoneme label -- confirm
    # against the script that produced the .npy files.
    outputs = [phn_table[bounds[0]:bounds[1], 2] for bounds in seq_to_phn]

    ds = IndexableDataset({'input': raw_inputs, 'output': outputs})
    stream = DataStream(
        ds, iteration_scheme=ShuffledExampleScheme(len(raw_inputs)))

    # Gather several batches' worth of examples and sort them together.
    stream = Batch(
        stream,
        iteration_scheme=ConstantScheme(batch_size * sort_batch_count))
    # Look the source index up by name: the dataset was built from a dict,
    # so the position of 'input' among the sources is not fixed here.
    comparison = _balanced_batch_helper(stream.sources.index('input'))
    stream = Mapping(stream, SortMapping(comparison))
    stream = Unpack(stream)

    # Regroup the sorted examples into batches of the requested size.
    stream = Batch(
        stream,
        iteration_scheme=ConstantScheme(batch_size,
                                        num_examples=len(raw_inputs)))
    stream = Padding(stream, mask_sources=['input', 'output'])

    return ds, stream


if __name__ == "__main__":
    # Smoke test: print the first four batches of the training stream.
    # sort_batch_count is a required argument of setup_datastream; leaving it
    # out raises TypeError.  The value 10 is arbitrary for this demo.
    ds, stream = setup_datastream(
        path='/home/lx.nobackup/datasets/timit/readable',
        batch_size=2,
        sort_batch_count=10)

    for i, d in enumerate(stream.get_epoch_iterator()):
        print('--')
        print(d)

        if i > 2:
            break

# vim: set sts=4 ts=4 sw=4 tw=0 et :