diff --git a/decoder.py b/decoder.py index 339a1f5..e594a5b 100644 --- a/decoder.py +++ b/decoder.py @@ -57,6 +57,11 @@ def run_sampler(dec, c, beam_width=1, stochastic=False, use_unk=False): """ Generate text conditioned on c """ + if stochastic and beam_width > 1: + print ("Beam search does not support stochastic sampling. " + + "Setting beam_width to 1\n") + beam_width = 1 + sample, score = gen_sample(dec['tparams'], dec['f_init'], dec['f_next'], c.reshape(1, dec['options']['dimctx']), dec['options'], trng=dec['trng'], k=beam_width, maxlen=1000, stochastic=stochastic, @@ -64,6 +69,7 @@ def run_sampler(dec, c, beam_width=1, stochastic=False, use_unk=False): text = [] if stochastic: sample = [sample] + score = [score] for c in sample: text.append(' '.join([dec['word_idict'][w] for w in c[:-1]])) diff --git a/generate.py b/generate.py index 9d8a789..70da521 100644 --- a/generate.py +++ b/generate.py @@ -31,7 +31,7 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True -def story(z, image_loc, k=100, bw=50, lyric=False): +def story(z, image_loc, k=100, bw=50, lyric=False, stochastic=False): """ Generate a story for an image at location image_loc """ @@ -62,7 +62,7 @@ def story(z, image_loc, k=100, bw=50, lyric=False): shift = svecs.mean(0) - z['bneg'] + z['bpos'] # Generate story conditioned on shift - passage = decoder.run_sampler(z['dec'], shift, beam_width=bw) + passage = decoder.run_sampler(z['dec'], shift, beam_width=bw, stochastic=stochastic) print 'OUTPUT: ' if lyric: for line in passage.split(','):