| | import torch |
| |
|
def _mean_log_max_attention(align_mat, indices, device):
    """Return the mean of log(max attention) over the selected columns.

    Args:
        align_mat: 2-D alignment tensor; ``indices`` address its dim 1.
            (Presumably dim 0 runs over memory positions -- confirm
            against the encoder that builds it.)
        indices: iterable of integer positions along dim 1; duplicates
            and ordering are irrelevant.
        device: device on which the index tensor is created; it must be
            the device ``align_mat`` lives on for ``index_select``.

    Returns:
        A 0-dim tensor. When ``indices`` is empty the result is zero
        instead of the NaN produced by taking the mean of an empty tensor.
    """
    unique_indices = sorted(set(indices))
    if not unique_indices:
        # Nothing to align: contribute nothing to the loss rather than NaN.
        return align_mat.new_zeros(())

    index_t = torch.tensor(unique_indices, dtype=torch.long, device=device)
    selected = align_mat.index_select(1, index_t)
    # Strongest attention each selected item receives along dim 0.
    max_att, _ = selected.max(dim=0)
    # Clamp before the log so an all-zero column cannot yield -inf.
    return torch.log(max_att.clamp(min=1e-9)).mean()


def compute_align_loss(model, desc_enc, example):
    """Compute the alignment loss for one example.

    For every column and table referenced in the example's tree, take the
    largest value in the corresponding column of the alignment matrix and
    penalise its negative log, averaged separately over columns and tables.

    Args:
        model: a nl2code decoder; must expose
            ``ast_wrapper.find_all_descendants_of_type`` and ``_device``.
        desc_enc: encoder output exposing ``m2c_align_mat`` and
            ``m2t_align_mat``.
        example: object whose ``tree`` attribute is the root of the target
            tree.

    Returns:
        A 0-dim tensor: ``-mean(log max_att(columns)) - mean(log
        max_att(tables))``. A side with no referenced items contributes 0.
    """
    root_node = example.tree
    find_descendants = model.ast_wrapper.find_all_descendants_of_type
    rel_cols = list(find_descendants(root_node, "column"))
    rel_tabs = list(find_descendants(root_node, "table"))

    col_term = _mean_log_max_attention(
        desc_enc.m2c_align_mat, rel_cols, model._device)
    tab_term = _mean_log_max_attention(
        desc_enc.m2t_align_mat, rel_tabs, model._device)

    return -col_term - tab_term
| |
|
| |
|
def compute_pointer_with_align(
        model,
        node_type,
        prev_state,
        prev_action_emb,
        parent_h,
        parent_action_emb,
        desc_enc):
    """Advance the decoder one step and score columns/tables via alignment.

    The decoder state is updated with ``model._update_state``; the pointer
    network for ``node_type`` then scores ``desc_enc.memory``. Those scores
    are softmaxed over dim 1 and multiplied by the matching alignment
    matrix to obtain a distribution over columns or tables.

    Args:
        model: decoder exposing ``_update_state`` and a ``pointers`` mapping
            keyed by node type.
        node_type: ``"column"`` or ``"table"``; anything else fails the
            internal assertion.
        prev_state, prev_action_emb, parent_h, parent_action_emb: forwarded
            unchanged to ``model._update_state``.
        desc_enc: encoder output exposing ``memory``, ``m2c_align_mat`` and
            ``m2t_align_mat``.

    Returns:
        Tuple ``(output, new_state, pointer_logits, attention_weights)``
        where ``output`` is ``new_state[0]`` and ``pointer_logits`` is the
        log of the aligned probabilities, clamped below at 1e-9.
    """
    new_state, attention_weights = model._update_state(
        node_type, prev_state, prev_action_emb, parent_h,
        parent_action_emb, desc_enc)
    output = new_state[0]

    # Distribution over memory positions from the pointer network.
    memory_scores = model.pointers[node_type](output, desc_enc.memory)
    memory_probs = torch.softmax(memory_scores, dim=1)

    # Pick the alignment matrix that maps memory positions to the target kind.
    if node_type == "column":
        align_mat = desc_enc.m2c_align_mat
    else:
        assert node_type == "table"
        align_mat = desc_enc.m2t_align_mat

    # Clamp before the log so zero-probability entries do not become -inf.
    pointer_logits = torch.mm(memory_probs, align_mat).clamp(min=1e-9).log()
    return output, new_state, pointer_logits, attention_weights
| |
|