Why can’t you pickle generators in Python? A pattern for saving training state

Summary

A pattern for persisting generators is to turn them into pickle-able class objects. This is useful when you use generators for streaming training examples.

I would also try generator_tools, which might be a more convenient alternative to the pattern I describe. I haven’t used it yet.


Generators for streaming training examples

For machine learning, Python generators are a simple idiom that makes it easy to generate a stream of training examples. Moreover, you can nest generators:

  • The inner generator can be used to read one example at a time.
  • The outer generator can be used to read examples from the inner generator until you have a full minibatch, and then yield this minibatch.

Here is some example code:

[Update: The example holds without the ALL-CAPS magic variable name “HYPERPARAMETERS”. However, I include HYPERPARAMETERS because this is the actual code I am using. Hyperparameters are global, read-only variables that specify the particular experimental condition being tested. I can’t say that I have the best solution to this particular aspect of experimental control (hyperparameters). I might write a blog post about it in the future, to solicit feedback on improved methods. However, I have refined my current approach over several years, and I can assure you that it is far less painful than a handful of seemingly “cleaner” approaches.]
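
For concreteness, here is a hypothetical stand-in for the author’s common.hyperparameters helper, so the snippets below are easier to follow. The module layout and the example values are my assumptions; only the key names are taken from the code below:

# common/hyperparameters.py (hypothetical sketch)
_SETTINGS = {
    "language-model": {
        "TRAIN_SENTENCES": "train.txt",
        "WINDOW_SIZE": 5,
        "MINIBATCH SIZE": 100,
    },
}

def read(name):
    # Return the global, read-only settings for one experimental condition.
    return _SETTINGS[name]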

# NB: myopen(), common.hyperparameters, and vocabulary.wordmap are helpers
# from the author's codebase.
def get_train_example():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")
    from vocabulary import wordmap
    for l in myopen(HYPERPARAMETERS["TRAIN_SENTENCES"]):
        prevwords = []
        for w in l.split():
            if wordmap.exists(w):
                prevwords.append(wordmap.id(w))
                if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                    # Yield a window over the most recent WINDOW_SIZE word ids.
                    yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
            else:
                # An out-of-vocabulary word breaks the context window.
                prevwords = []

def get_train_minibatch():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")
    minibatch = []
    for e in get_train_example():
        minibatch.append(e)
        if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
            assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
            yield minibatch
            minibatch = []

You can’t persist training state by pickling your generators

However, generators become problematic when you want to persist your experiment’s state in order to later restart training at the same place. Unfortunately, you can’t pickle generators in Python, and it can be a bit of a PITA to work around this in order to save the training state.
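
For example, a minimal demonstration (Python 2, matching the code above):

import pickle

gen = (n * n for n in xrange(10))
pickle.dumps(gen)
# TypeError: can't pickle generator objects

A generator’s execution state lives in a paused stack frame, which pickle does not know how to serialize.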

Pattern to work around this annoyance

Following useful discussion on pylearn-dev and stackoverflow [1] [2], I propose the following pattern for converting generators to pickle-able class objects:

  1. Convert the generator to a class in which the generator code becomes the __iter__ method.
  2. Add __getstate__ and __setstate__ methods to the class to handle pickling. Remember that you can’t pickle file objects, so the restored state must allow you to re-open files and fast-forward to the saved position as necessary (see the toy sketch below).
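
To make the pattern concrete before the full code, here is a toy instance of it (a sketch of mine, not from the post): a stream of integers whose position survives pickling, because the iteration state lives on the instance rather than in a generator frame.

class CountStream(object):
    def __init__(self):
        self.count = 0

    def __iter__(self):
        # The generator code is the __iter__ method (step 1); its only
        # iteration state, self.count, lives on the instance.
        while True:
            self.count += 1
            yield self.count

    def __getstate__(self):
        return self.count

    def __setstate__(self, count):
        self.count = count

import pickle
s = CountStream()
it = iter(s)
print it.next(), it.next(), it.next()    # 1 2 3
s2 = pickle.loads(pickle.dumps(s))
print iter(s2).next()                    # 4: the copy resumes where s left off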

Here is the updated code, after applying this pattern:

import sys

class TrainingExampleStream(object):
    def __init__(self):
        # Initialize the state variables here too, in case pickling happens
        # before __iter__ is called.
        self.filename = None
        self.count = 0
        self.skip = 0

    def __iter__(self):
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        from vocabulary import wordmap
        self.filename = HYPERPARAMETERS["TRAIN_SENTENCES"]
        # self.skip is nonzero only after unpickling: silently replay that
        # many examples, so iteration resumes where the pickled stream
        # left off.
        skip, self.skip = self.skip, 0
        self.count = 0
        for l in myopen(self.filename):
            prevwords = []
            for w in l.split():
                if wordmap.exists(w):
                    prevwords.append(wordmap.id(w))
                    if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                        self.count += 1
                        if self.count > skip:
                            yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
                else:
                    # An out-of-vocabulary word breaks the context window.
                    prevwords = []

    def __getstate__(self):
        # File objects cannot be pickled, so we save only the filename and
        # the number of examples consumed so far.
        return self.filename, self.count

    def __setstate__(self, state):
        """
        Restore the stream position. The file is re-opened, and the
        already-consumed examples are replayed, the next time __iter__ is
        called. (Note that pickle does not call __init__ on unpickling.)

        @warning: We ignore the pickled filename and trust
        HYPERPARAMETERS["TRAIN_SENTENCES"] instead, because moving to a
        different filesystem can change the path to the same underlying
        file. We only issue a warning if the two differ.
        """
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        self.filename, self.count = state
        self.skip = self.count
        print >> sys.stderr, "__setstate__(%r): will replay %d examples" % (state, self.skip)
        if self.filename != HYPERPARAMETERS["TRAIN_SENTENCES"]:
            print >> sys.stderr, "warning: pickled filename %s != HYPERPARAMETERS['TRAIN_SENTENCES'] %s" % (self.filename, HYPERPARAMETERS["TRAIN_SENTENCES"])

class TrainingMinibatchStream(object):
    def __init__(self):
        self.get_train_example = TrainingExampleStream()

    def __iter__(self):
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        minibatch = []
        # Reuse self.get_train_example rather than constructing a fresh
        # stream here, so that a stream restored by __setstate__ resumes
        # from the saved position instead of the start of the file.
        for e in self.get_train_example:
            minibatch.append(e)
            if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
                assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
                yield minibatch
                minibatch = []

    def __getstate__(self):
        # NB: pickle between minibatches; examples sitting in a partially
        # assembled minibatch are not saved.
        return (self.get_train_example.__getstate__(),)

    def __setstate__(self, state):
        # pickle does not call __init__, so construct the inner stream here.
        self.get_train_example = TrainingExampleStream()
        self.get_train_example.__setstate__(state[0])
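
For illustration, here is how training might checkpoint and resume with these streams. This is a sketch: train_on() and the checkpoint path are hypothetical, not from the post.

import cPickle as pickle

stream = TrainingMinibatchStream()
for i, minibatch in enumerate(stream):
    train_on(minibatch)                      # hypothetical training step
    if (i + 1) % 1000 == 0:
        # Checkpoint between minibatches, after training on the current one.
        pickle.dump(stream, open("checkpoint.pkl", "wb"))

# Later, after a crash or restart:
stream = pickle.load(open("checkpoint.pkl", "rb"))
for minibatch in stream:
    train_on(minibatch)                      # resumes after the last checkpoint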

Comments

  • http://josephturian.blogspot.com/ Joseph Turian

    Response to criticism on Hacker News:

    (1) You can only pickle generators that generate the same sequence every time they are restarted.

    I don’t know how you can persist state if you do not make this assumption.

    (2) All the work the generator did prior to pickling must be performed again on unpickling.

    Something faster would be to use file.tell() to get the state and file.seek() to set the state. Since the “unpickling” is not a bottleneck, I didn’t optimize this.
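
    For concreteness, a minimal sketch (an illustration, not code from the post) of this tell()/seek() variant: pickle the byte offset instead of replaying examples. readline() is used rather than file iteration, because tell() is unreliable with Python 2’s read-ahead file iterator.

        class SeekableLineStream(object):
            # Hypothetical: streams (split) lines from a file, resumably.
            def __init__(self, filename):
                self.filename = filename
                self.offset = 0

            def __iter__(self):
                f = open(self.filename)
                f.seek(self.offset)
                while True:
                    line = f.readline()
                    if not line:
                        break
                    # Record the offset of the next line before yielding, so
                    # a pickle taken mid-iteration resumes at the right line.
                    self.offset = f.tell()
                    yield line.split()

            def __getstate__(self):
                return self.filename, self.offset

            def __setstate__(self, state):
                self.filename, self.offset = state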

  • http://peadrop.com/blog/2009/12/29/why-you-cannot-pickle-generator/ Why you cannot pickle generator (pingback)

