Summary
One pattern for persisting Python generators is to turn them into picklable class objects. This is useful when you use generators to stream training examples.
I would also suggest trying generator_tools, which might be a more convenient alternative to the pattern I describe, although I haven’t used it myself yet.
Generators for streaming training examples
For machine learning, Python generators are a simple idiom that makes it easy to produce a stream of training examples. Moreover, you can nest generators:
- The inner generator can be used to read one example at a time.
- The outer generator can be used to read examples from the inner generator until you have a full minibatch, and then yield this minibatch.
Here is some example code:
[Update: The example would work just as well without the ALL CAPS magic variable name “HYPERPARAMETERS”. I include it only because this is the actual code I am using. Hyperparameters are global, read-only variables that specify the particular experimental condition being tested. I can’t claim to have the best solution to this aspect of experimental control, and I might write a future blog post about it to solicit feedback on better methods. However, I have refined my current approach over several years, and I can assure you that it is far less painful than a handful of “cleaner” approaches.]
```python
import common.hyperparameters  # the author's own module
# myopen() and vocabulary.wordmap are also the author's own helpers.


def get_train_example():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")
    from vocabulary import wordmap
    for l in myopen(HYPERPARAMETERS["TRAIN_SENTENCES"]):
        prevwords = []
        for w in l.split():
            if wordmap.exists(w):
                prevwords.append(wordmap.id(w))
                # Yield the most recent WINDOW_SIZE in-vocabulary word ids.
                if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                    yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
            else:
                # An out-of-vocabulary word breaks the window.
                prevwords = []


def get_train_minibatch():
    HYPERPARAMETERS = common.hyperparameters.read("language-model")
    minibatch = []
    for e in get_train_example():
        minibatch.append(e)
        if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
            assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
            yield minibatch
            minibatch = []
```
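The code above depends on my own helpers (common.hyperparameters, myopen, vocabulary). As a self-contained sketch of the same nesting, assuming an in-memory list of sentences instead of a file and skipping the vocabulary lookup, the two generators might look like this. The names examples and minibatches are hypothetical.

```python
def examples(sentences, window):
    """Inner generator: yield every run of `window` consecutive tokens, one sentence at a time."""
    for sentence in sentences:
        tokens = sentence.split()
        for end in range(window, len(tokens) + 1):
            yield tokens[end - window:end]


def minibatches(example_stream, size):
    """Outer generator: collect `size` examples, then yield them as one minibatch."""
    minibatch = []
    for e in example_stream:
        minibatch.append(e)
        if len(minibatch) == size:
            yield minibatch
            minibatch = []
    # A trailing partial minibatch is dropped, as in get_train_minibatch.
```

For example, minibatches(examples(["the cat sat", "a dog ran far"], 2), 2) yields two minibatches of two windows each; the fifth window, ["ran", "far"], never fills a minibatch and is dropped.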
You can’t persist training state by pickling your generators
However, generators become problematic when you want to persist your experiment’s state, so that you can later restart training from the same point. Unfortunately, Python can’t pickle generator objects, and it can be a bit of a PITA to work around this in order to save the training state.
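A quick way to see the problem is to try pickling a generator directly. This sketch uses a hypothetical toy generator, count_up, in place of a real training stream; in Python 3, pickle.dumps raises TypeError for generator objects.

```python
import pickle


def count_up(n):
    """A trivial generator standing in for a training-example stream."""
    for i in range(n):
        yield i


def is_picklable(obj):
    """Return True if pickle can serialize obj, False if it raises TypeError."""
    try:
        pickle.dumps(obj)
    except TypeError:
        return False
    return True


gen = count_up(5)
next(gen)                       # consume one item: this position is what we want to save
print(is_picklable(gen))        # False: generator objects cannot be pickled
print(is_picklable(list(gen)))  # True: the remaining items, as a plain list, pickle fine
```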
Pattern to work around this annoyance
Following useful discussions on pylearn-dev and Stack Overflow [1] [2], I propose the following pattern for converting generators into picklable class objects:
- Convert the generator to a class in which the generator code is the __iter__ method.
- Add __getstate__ and __setstate__ methods to the class to handle pickling. Remember that you can’t pickle file objects, so __setstate__ will have to re-open files as necessary.
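Here is a minimal sketch of both steps applied to a generator that simply yields the lines of a text file. LineStream and train.txt are hypothetical names, and the sketch assumes Python 3.3+ (for yield from) and a file that does not change between saving and restoring. One detail matters: __setstate__ must not only fast-forward a fresh generator but also keep it (here in _resumed), so that the next __iter__ call continues that pass instead of starting over.

```python
import os
import pickle
import tempfile


class LineStream(object):
    """Picklable replacement for: def lines(path): yield each line of open(path)."""

    def __init__(self, path):
        self.path = path
        self.count = 0        # number of lines yielded so far in the current pass
        self._resumed = None  # fast-forwarded generator left by __setstate__

    def __iter__(self):
        # After unpickling, continue the pass that __setstate__ fast-forwarded.
        resumed, self._resumed = self._resumed, None
        if resumed is not None:
            yield from resumed
            return
        # Otherwise, this is the original generator code: a fresh pass over the file.
        self.count = 0
        with open(self.path) as f:
            for line in f:
                self.count += 1
                yield line.rstrip("\n")

    def __getstate__(self):
        # Only plain data is pickled; the open file and any generator are not.
        return self.path, self.count

    def __setstate__(self, state):
        path, count = state
        # __init__ is not called when unpickling, so set every field here.
        self.path, self.count, self._resumed = path, 0, None
        it = self.__iter__()       # re-opens the file on the first next()
        while self.count < count:  # replay up to the saved position
            next(it)
        self._resumed = it


# Usage: read two lines, checkpoint, then resume from the unpickled copy.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.txt")
    with open(path, "w") as f:
        f.write("a\nb\nc\nd\n")
    stream = LineStream(path)
    it = iter(stream)
    first_two = [next(it), next(it)]
    checkpoint = pickle.dumps(stream)  # works: only (path, count) is saved
    it.close()                         # done with the original pass
    restored = pickle.loads(checkpoint)
    rest = list(restored)              # continues with the third line
```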
Here is the updated code, after applying this pattern:
```python
import sys

import common.hyperparameters  # the author's own module
# myopen() and vocabulary.wordmap are also the author's own helpers.


class TrainingExampleStream(object):
    def __init__(self):
        # Set the state variables, in case pickling happens before __iter__ is called.
        self.filename = None
        self.count = 0
        # Fast-forwarded generator left by __setstate__, consumed by the next __iter__.
        self._resumed = None

    def __iter__(self):
        # After unpickling, continue the pass that __setstate__ fast-forwarded,
        # instead of restarting from the top of the file.
        resumed, self._resumed = self._resumed, None
        if resumed is not None:
            yield from resumed
            return
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        from vocabulary import wordmap
        self.filename = HYPERPARAMETERS["TRAIN_SENTENCES"]
        self.count = 0
        for l in myopen(self.filename):
            prevwords = []
            for w in l.split():
                if wordmap.exists(w):
                    prevwords.append(wordmap.id(w))
                    if len(prevwords) >= HYPERPARAMETERS["WINDOW_SIZE"]:
                        self.count += 1
                        yield prevwords[-HYPERPARAMETERS["WINDOW_SIZE"]:]
                else:
                    prevwords = []

    def __getstate__(self):
        return self.filename, self.count

    def __setstate__(self, state):
        """
        @warning: We ignore the filename. If we wanted
        to be really fastidious, we would assume that
        HYPERPARAMETERS["TRAIN_SENTENCES"] might change. The only
        problem is that if we change filesystems, the filename
        might change just because the base file is in a different
        path. So we issue a warning if the filename is different from what is expected.
        """
        filename, count = state
        print("__setstate__(%r)..." % (state,), file=sys.stderr)
        # __init__ is not called when unpickling, so initialize the fields here.
        self.filename = None
        self.count = 0
        self._resumed = None
        # Re-open the file and replay it until we reach the saved position,
        # then keep the generator so that the next __iter__ continues from here.
        it = self.__iter__()
        while self.count < count:
            next(it)
        self._resumed = it
        if self.filename is not None and self.filename != filename:
            print("self.filename %s != filename given to __setstate__ %s"
                  % (self.filename, filename), file=sys.stderr)
        print("...__setstate__(%r)" % (state,), file=sys.stderr)


class TrainingMinibatchStream(object):
    def __init__(self):
        # Create the example stream up front, so __getstate__ works before __iter__.
        self.get_train_example = TrainingExampleStream()

    def __iter__(self):
        HYPERPARAMETERS = common.hyperparameters.read("language-model")
        minibatch = []
        # Reuse the existing example stream rather than creating a new one here:
        # after __setstate__ it resumes from the saved position, and otherwise
        # each pass over it starts from the top of the file.
        for e in self.get_train_example:
            minibatch.append(e)
            if len(minibatch) >= HYPERPARAMETERS["MINIBATCH SIZE"]:
                assert len(minibatch) == HYPERPARAMETERS["MINIBATCH SIZE"]
                yield minibatch
                minibatch = []

    def __getstate__(self):
        return (self.get_train_example.__getstate__(),)

    def __setstate__(self, state):
        """
        @warning: We ignore the filename.
        """
        self.get_train_example = TrainingExampleStream()
        self.get_train_example.__setstate__(state[0])
```
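To checkpoint training, you would pickle the minibatch stream along with the rest of the training state and unpickle it on restart. The runnable sketch below uses a hypothetical NumberStream in place of TrainingExampleStream (it yields 0 to n-1 instead of reading a file) and a hypothetical MinibatchStream in place of TrainingMinibatchStream. As a design variation, MinibatchStream defines no __getstate__ of its own: pickle's default handling of its __dict__ calls NumberStream's __getstate__ and __setstate__, which is enough because the outer object holds no file or generator itself.

```python
import pickle


class NumberStream(object):
    """Hypothetical stand-in for TrainingExampleStream: yields 0, 1, ..., n-1."""

    def __init__(self, n):
        self.n = n
        self.count = 0
        self._resumed = None

    def __iter__(self):
        # Continue a pass fast-forwarded by __setstate__, if there is one.
        resumed, self._resumed = self._resumed, None
        if resumed is not None:
            yield from resumed
            return
        self.count = 0
        for i in range(self.n):
            self.count += 1
            yield i

    def __getstate__(self):
        return self.n, self.count

    def __setstate__(self, state):
        self.n, count = state
        self.count, self._resumed = 0, None
        it = self.__iter__()
        while self.count < count:  # replay up to the saved position
            next(it)
        self._resumed = it


class MinibatchStream(object):
    """Hypothetical stand-in for TrainingMinibatchStream, using default pickling."""

    def __init__(self, n, size):
        self.size = size
        self.examples = NumberStream(n)

    def __iter__(self):
        minibatch = []
        for e in self.examples:
            minibatch.append(e)
            if len(minibatch) == self.size:
                yield minibatch
                minibatch = []


stream = MinibatchStream(n=10, size=3)
it = iter(stream)
first = [next(it), next(it)]       # train on two minibatches
checkpoint = pickle.dumps(stream)  # e.g. write this to disk with the model
restored = pickle.loads(checkpoint)
resumed_batches = list(restored)   # continues after example 5
fresh_batches = list(restored)     # a later pass starts from the top again
```

A checkpoint taken between minibatches is sufficient: the inner stream’s count is then an exact multiple of the minibatch size, so no partially built minibatch is lost.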
