Why can’t you pickle generators in Python? A pattern for saving training state

March 10th, 2012

Summary

A pattern for persisting generators is to turn them into picklable class objects. This is useful when you use generators to stream training examples.

I would also try generator_tools, which might be a more convenient alternative to the pattern I describe. I haven’t used it yet.


Generators for streaming training examples

For machine learning, Python generators are a simple idiom that makes it easy to produce a stream of training examples. Moreover, you can nest generators:

  • The inner generator can be used to read one example at a time.
  • The outer generator can be used to read examples from the inner generator until you have a full minibatch, and then yield this minibatch.

Here is some example code:

[Update: The example works without the ALL-CAPS magic variable names ("HYPERPARAMETERS"). However, I include HYPERPARAMETERS because this is the actual code I am using. Hyperparameters are global, read-only variables that specify the particular experimental condition being tested. I can't say that I have the best solution to this aspect of experimental control, and I might write a blog post about it in the future to solicit feedback on improved methods. However, I have refined my current approach over several years, and I can assure you that it is far less painful than a handful of seemingly "cleaner" approaches.]

def get_train_example():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")

    from vocabulary import wordmap
    for l in myopen(HYPERPARAMETERS["TRAIN_SENTENCES"]):
        prevwords = []
        for w in l.split():
            if wordmap.exists(w):
                prevwords.append(wordmap.id(w))
                if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                    yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
            else:
                # Unknown word: break the context window.
                prevwords = []

def get_train_minibatch():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")
    minibatch = []
    for e in get_train_example():
        minibatch.append(e)
        if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
            assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
            yield minibatch
            minibatch = []
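Stripped of the experiment-specific machinery (HYPERPARAMETERS, myopen, the wordmap vocabulary), the nesting is easy to see in a self-contained sketch. The window and minibatch sizes below are made-up stand-ins:

```python
WINDOW_SIZE = 3
MINIBATCH_SIZE = 2

def get_example(tokens):
    # Inner generator: slide a fixed-size window over a token stream.
    prev = []
    for t in tokens:
        prev.append(t)
        if len(prev) >= WINDOW_SIZE:
            yield prev[-WINDOW_SIZE:]

def get_minibatch(tokens):
    # Outer generator: accumulate inner examples into full minibatches.
    minibatch = []
    for e in get_example(tokens):
        minibatch.append(e)
        if len(minibatch) == MINIBATCH_SIZE:
            yield minibatch
            minibatch = []

for mb in get_minibatch("the quick brown fox jumps over".split()):
    print(mb)
```

Six tokens give four windows, which the outer generator groups into two minibatches of two.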

You can’t persist training state by pickling your generators

However, generators become problematic when you want to persist your experiment's state in order to later restart training from the same place. Unfortunately, you can't pickle generators in Python, and it can be a bit of a PITA to work around this in order to save the training state.
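The failure is easy to reproduce. A generator's state lives in a suspended stack frame, which pickle has no way to serialize, so it raises a TypeError:

```python
import pickle

def count_up(n):
    # A trivial generator; its suspended frame is what pickle cannot handle.
    for i in range(n):
        yield i

gen = count_up(10)
next(gen)   # advance it, so the generator now holds live state

try:
    pickle.dumps(gen)
except TypeError as e:
    print("pickling failed:", e)
```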

Pattern to work around this annoyance

Following useful discussion on pylearn-dev and Stack Overflow [1] [2], I propose the following pattern for converting generators to picklable class objects:

  1. Convert the generator to a class in which the generator code becomes the __iter__ method.
  2. Add __getstate__ and __setstate__ methods to the class to handle pickling. Remember that you can’t pickle file objects, so __setstate__ will have to re-open files as necessary.
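As a minimal sketch of both steps (with a throwaway temp file standing in for TRAIN_SENTENCES, and class/variable names of my own invention), __getstate__ returns only the picklable position, and __iter__ re-opens the file and skips past already-consumed lines after a restore:

```python
import os
import pickle
import tempfile

class LineStream(object):
    """Stream lines from a file; picklable as (filename, count)."""

    def __init__(self, filename):
        self.filename = filename
        self.count = 0   # number of lines yielded so far

    def __iter__(self):
        # The old generator body lives here.  Skip any lines that were
        # already consumed before pickling, so iteration can resume.
        with open(self.filename) as f:
            for i, line in enumerate(f):
                if i < self.count:
                    continue
                self.count += 1
                yield line.rstrip("\n")

    def __getstate__(self):
        # File objects cannot be pickled; persist only the position.
        return (self.filename, self.count)

    def __setstate__(self, state):
        # No file handle to restore here: __iter__ re-opens the file
        # and fast-forwards using self.count.
        self.filename, self.count = state

# Demo: consume two lines, checkpoint with pickle, resume from the copy.
fd, path = tempfile.mkstemp()
with os.fdopen(fd, "w") as f:
    f.write("a\nb\nc\nd\n")

stream = LineStream(path)
it = iter(stream)
print(next(it), next(it))   # a b
restored = pickle.loads(pickle.dumps(stream))
print(list(restored))       # ['c', 'd']
os.remove(path)
```

This sketch fast-forwards by skipping lines rather than replaying them, which is cheap for a plain line stream; replaying is only necessary when iteration rebuilds other state along the way.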

Here is the updated code, after applying this pattern:

class TrainingExampleStream(object):
    def __init__(self):
        # Set the state variables, in case pickling happens before __iter__ is called.
        self.filename = None
        self.count = 0

    def __iter__(self):
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        from vocabulary import wordmap
        self.filename = HYPERPARAMETERS["TRAIN_SENTENCES"]
        # If __setstate__ restored a nonzero count, fast-forward past the
        # examples that were already consumed before pickling.
        skip = self.count
        self.count = 0
        for l in myopen(self.filename):
            prevwords = []
            for w in l.split():
                if wordmap.exists(w):
                    prevwords.append(wordmap.id(w))
                    if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                        self.count += 1
                        if self.count > skip:
                            yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
                else:
                    prevwords = []

    def __getstate__(self):
        # File objects cannot be pickled, so we persist only our position.
        return self.filename, self.count

    def __setstate__(self, state):
        """
        @warning: We ignore the filename.  If we wanted to be really
        fastidious, we would assume that HYPERPARAMETERS["TRAIN_SENTENCES"]
        might change.  The only problem is that if we change filesystems,
        the filename might change just because the base file is in a
        different path.  So we merely warn if the filename differs from
        what is expected.
        """
        filename, count = state
        print >> sys.stderr, ("__setstate__(%s)..." % repr(state))
        # __init__ is not called during unpickling, so initialize here.
        self.filename = None
        self.count = 0
        # Replay the stream until we reach the pickled position.
        it = self.__iter__()
        while count != self.count:
            it.next()
        if self.filename != filename:
            print >> sys.stderr, ("self.filename %s != filename given to __setstate__ %s" % (self.filename, filename))
        print >> sys.stderr, ("...__setstate__(%s)" % repr(state))

class TrainingMinibatchStream(object):
    def __init__(self):
        # Create the inner stream once, here, so that state restored by
        # __setstate__ is not clobbered when iteration starts.
        self.get_train_example = TrainingExampleStream()

    def __iter__(self):
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        minibatch = []
        for e in self.get_train_example:
            minibatch.append(e)
            if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
                assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
                yield minibatch
                minibatch = []

    def __getstate__(self):
        # Delegate pickling to the inner stream.
        return (self.get_train_example.__getstate__(),)

    def __setstate__(self, state):
        """
        @warning: We ignore the filename, as above.
        """
        self.get_train_example = TrainingExampleStream()
        self.get_train_example.__setstate__(state[0])
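To see the delegation end to end without the experiment code, here is a toy stand-in (NumberStream and MinibatchStream are invented names): the outer stream pickles and restores itself entirely through its inner stream's state.

```python
import pickle

class NumberStream(object):
    """Toy inner stream: yields 0 .. n-1, picklable by position."""

    def __init__(self, n):
        self.n = n
        self.count = 0   # how many numbers have been yielded

    def __iter__(self):
        while self.count < self.n:
            self.count += 1
            yield self.count - 1

    def __getstate__(self):
        return (self.n, self.count)

    def __setstate__(self, state):
        self.n, self.count = state

class MinibatchStream(object):
    """Toy outer stream: groups inner examples into minibatches and
    delegates all pickling state to the inner stream."""

    def __init__(self, inner, size):
        self.inner = inner
        self.size = size

    def __iter__(self):
        minibatch = []
        for e in self.inner:
            minibatch.append(e)
            if len(minibatch) == self.size:
                yield minibatch
                minibatch = []

    def __getstate__(self):
        return (self.inner.__getstate__(), self.size)

    def __setstate__(self, state):
        inner_state, self.size = state
        self.inner = NumberStream(0)
        self.inner.__setstate__(inner_state)

stream = MinibatchStream(NumberStream(6), 2)
it = iter(stream)
print(next(it))              # [0, 1]
restored = pickle.loads(pickle.dumps(stream))
print(list(restored))        # [[2, 3], [4, 5]]
```

One caveat, shared with the classes above: examples sitting in a partially filled minibatch at pickle time have already been counted by the inner stream, so they are dropped on resume; checkpoints effectively land on example boundaries, not minibatch boundaries.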