| | import copy |
| | import operator |
| |
|
| | import attr |
| |
|
| |
|
@attr.attrs
class Hypothesis:
    """A single decoding path tracked by ``beam_search``.

    A hypothesis pairs an inference state with the history of choices (and
    their per-step scores) that produced it. ``beam_search`` creates the root
    hypothesis with empty histories and a score of 0, then builds each child
    by extending the parent's histories by one step.
    """

    # Inference state reached after applying every choice in choice_history.
    inference_state = attr.attrib()
    # Iterable of (choice, choice_score) pairs to expand next; beam_search
    # stores None here for hypotheses that have finished decoding.
    next_choices = attr.attrib()
    # Cumulative score: sum of the per-step scores along this path.
    score = attr.attrib(default=0)

    # Choices taken so far, oldest first.
    choice_history = attr.attrib(default=attr.Factory(list))
    # Per-step scores, aligned index-for-index with choice_history.
    score_history = attr.attrib(default=attr.Factory(list))
| |
|
| |
|
def beam_search(model, orig_item, preproc_item, beam_size, max_steps, visualize_flag=False):
    """Decode an item with beam search and return the completed hypotheses.

    Args:
        model: Object exposing ``begin_inference(orig_item, preproc_item)``,
            which returns ``(inference_state, next_choices)``.
        orig_item: Original item, passed through to ``model.begin_inference``.
        preproc_item: Preprocessed item, passed through to
            ``model.begin_inference``.
        beam_size: Maximum number of hypotheses kept per step; also the
            number of finished hypotheses after which the search stops.
        max_steps: Upper bound on the number of expansion steps.
        visualize_flag: When true, print the index of each step.

    Returns:
        List of finished ``Hypothesis`` objects sorted by cumulative score,
        highest first. Hypotheses still unfinished after ``max_steps`` steps
        are discarded, so the list may be shorter than ``beam_size`` or empty.

    Notes:
        ``next_choices`` is iterated as ``(choice, choice_score)`` pairs, and
        ``choice_score.item()`` is used to get a Python number (presumably a
        scalar tensor -- confirm against the model). Inference states must
        provide ``clone()`` and ``step(choice)``; ``step`` returns the next
        choices, or ``None`` when the hypothesis is complete.
    """
    inference_state, next_choices = model.begin_inference(orig_item, preproc_item)
    beam = [Hypothesis(inference_state, next_choices)]
    finished = []

    for step in range(max_steps):
        if visualize_flag:
            print('step:')
            print(step)

        # Stop once enough hypotheses have finished, or when nothing is left
        # to expand (further steps could not add any finished hypotheses).
        if len(finished) >= beam_size or not beam:
            break

        # Score every one-step expansion of every live hypothesis. Each score
        # is converted with .item() only once: if scores are device tensors,
        # every .item() call can force a host/device synchronization.
        candidates = []
        for hyp in beam:
            for choice, choice_score in hyp.next_choices:
                step_score = choice_score.item()
                candidates.append((hyp, choice, step_score, hyp.score + step_score))

        # Keep only the best expansions that fit in the remaining slots.
        # list.sort is stable, so ties keep their generation order.
        candidates.sort(key=operator.itemgetter(3), reverse=True)
        candidates = candidates[:beam_size - len(finished)]

        beam = []
        for hyp, choice, choice_score, cum_score in candidates:
            # Clone before stepping: several candidates may share one parent,
            # and step() presumably advances the state in place.
            inference_state = hyp.inference_state.clone()
            next_choices = inference_state.step(choice)
            child = Hypothesis(
                inference_state,
                next_choices,  # None here marks the child as finished
                cum_score,
                hyp.choice_history + [choice],
                hyp.score_history + [choice_score])
            if next_choices is None:
                finished.append(child)
            else:
                beam.append(child)

    finished.sort(key=operator.attrgetter('score'), reverse=True)
    return finished
| |
|