1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
@torch.no_grad()
def word_cle(
    logits_arc: FP[T, "batch dependent=token head=token"],
    word_first_token_mask: Bool[T, "batch dependent=token"],
) -> Int[T, "batch dependent=token"]:
    """Calculate the MST of words using the Chu-Liu-Edmonds algorithm.

    Only the root (token index 0) and the first token of every word take part
    in the tree. The arc scores between those tokens are decoded sentence by
    sentence on CPU.

    Args:
        logits_arc: `FP[T, "batch dependent=token head=token"]`
            Logits of the arc prediction.
        word_first_token_mask: `Bool[T, "batch dependent=token"]`
            Mask of the first tokens of the words. Index 0 is reserved for the
            root and must be False. The caller's tensor is not modified.

    Returns:
        `Int[T, "batch dependent=token"]` (int64, on CPU) Predicted head token
        index for the first token of each word. The root and non-first tokens
        are set to 0.
    """
    assert not torch.any(word_first_token_mask[:, 0]), "index 0 must be the root"
    results = torch.zeros_like(word_first_token_mask, device="cpu", dtype=torch.long)

    # Work on a CPU numpy copy so the caller's mask is untouched, then add the
    # root (index 0) to every sentence so it can be selected as a head.
    masks = word_first_token_mask.to(device="cpu", copy=True).numpy()
    masks[:, 0] = True
    # detach(): if the logits already are CPU float64, .to() returns the same
    # tensor, and .numpy() would refuse it when it requires grad.
    scores = logits_arc.detach().to(device="cpu", dtype=torch.float64).numpy()

    # `res` is a row view of `results`, so in-place writes fill `results`.
    for logits, mask, res in zip(scores, masks, results):
        # token_ids[k] is the token position of word k; word 0 is the root.
        token_ids = np.where(mask)[0]
        if len(token_ids) < 2:
            # Only the root is present: there is no word to attach.
            continue
        word_logits = logits[mask][:, mask]
        word_preds, _ = chu_liu_edmonds(word_logits)
        # word_preds[0] belongs to the root and is skipped. Map each remaining
        # word's head (a word index) back to a token position, and write it at
        # that word's own token position (excluding the root's row).
        res[torch.from_numpy(token_ids[1:])] = torch.from_numpy(
            token_ids[word_preds[1:]]
        )
    return results
class RobertaDependencyParser(RobertaDependencyParserBase):
    """Dependency parser on top of a RoBERTa backbone, with edge label prediction.

    Arcs are scored with a bilinear form between two MLP projections of the
    backbone output and decoded into a word-level tree with `word_cle`.
    Edge labels are scored with a biaffine classifier between each dependent
    and its head: the gold head in training mode, the decoded head otherwise.
    """

    def __init__(
        self,
        backbone: XLMRobertaModel,
        padding_label: int,
        training_backbone: bool,
        dim_h_arc: int,
        dim_h_rel: int,
        num_edge_classes: int,
    ):
        """
        Args:
            backbone: Encoder whose `last_hidden_state` feeds all scorers.
            padding_label: Passed to the base class. Presumably stored there as
                `self._padding_label`, the value marking ignored positions in
                `heads`/`deprels` (TODO confirm in the base class).
            training_backbone: Passed to the base class.
            dim_h_arc: Size of the representations produced by the arc MLPs.
            dim_h_rel: Size of the representations produced by the label MLPs.
            num_edge_classes: Number of dependency relation labels.
        """
        super().__init__(backbone, padding_label, training_backbone)
        # Arc scorer: separate projections for the head and dependent roles.
        self.mlp_h_arc_head = Mlp([self._dim_backbone, dim_h_arc])
        self.mlp_h_arc_dep = Mlp([self._dim_backbone, dim_h_arc])
        self.linear_u12 = torch.nn.Linear(dim_h_arc, dim_h_arc)

        # Label scorer: a bilinear term (mat_w) plus a linear term over the
        # concatenated dependent/head representations (linear_u345).
        self.mlp_h_rel_head = Mlp([self._dim_backbone, dim_h_rel])
        self.mlp_h_rel_dep = Mlp([self._dim_backbone, dim_h_rel])
        self.linear_u345 = torch.nn.Linear(dim_h_rel * 2, num_edge_classes)
        self.mat_w = torch.nn.Parameter(
            torch.randn(num_edge_classes, dim_h_rel, dim_h_rel)
        )

    def _forward_arc(
        self,
        backbone_out: FP[T, "batch token _dim_backbone"],
        heads: Int[T, "batch dependent=token"],
    ) -> tuple[
        FP[T, ""],
        FP[T, "batch dependent=token head=token"],
        Int[T, "batch dependent=token"],
    ]:
        """Score every (dependent, head) token pair and decode word-level arcs.

        Returns:
            A tuple containing
            - `FP[T, ""]` Cross entropy loss of arc prediction.
            - `FP[T, "batch dependent=token head=token"]` Logits of the arc prediction.
            - `Int[T, "batch dependent=token"]` Predicted arcs, on `heads.device`.
        """
        BATCH, TOKEN, _ = backbone_out.shape

        h_head: FP[T, "batch head dim_h_arc"] = self.mlp_h_arc_head(backbone_out)
        h_dep: FP[T, "batch dependent dim_h_arc"] = self.mlp_h_arc_dep(backbone_out)

        # logits[b, i, j] = linear_u12(h_dep[b, i]) . h_head[b, j]
        term: FP[T, "batch dependent dim_h_arc"] = self.linear_u12(h_dep)
        logits: FP[T, "batch dependent head"] = torch.einsum(
            "bid,bjd->bij", term, h_head
        )
        # Tokens carrying a real head label are treated as word-first tokens.
        # Decoding needs no gradient, so hand over a detached view.
        arcs = word_cle(logits.detach(), heads != self._padding_label).to(
            heads.device
        )

        loss = torch.nn.functional.cross_entropy(
            logits.reshape(BATCH * TOKEN, TOKEN),
            heads.reshape(BATCH * TOKEN),
            ignore_index=self._padding_label,
        )
        return loss, logits, arcs

    def forward(
        self,
        input_ids: Int[T, "batch token"],
        attention_mask: Int[T, "batch token"],
        heads: Int[T, "batch dependent=token"],
        deprels: Int[T, "batch dependent=token"],
        **_,
    ) -> tuple[FP[T, ""], dict[str, T]]:
        r"""Dependency Parsing based on RoBERTa, with edge label prediction.

        Args:
            input_ids: `Int[T, "batch token"]`
                Batched token indices.
            attention_mask: `Int[T, "batch token"]`
                Attention mask passed to the backbone.
            heads: `Int[T, "batch dependent=token"]`
                Gold head token index of each dependent. Positions holding the
                padding label (including index 0, the root) are ignored. The
                remaining positions are used as the word-first-token mask for
                decoding.
            deprels: `Int[T, "batch dependent=token"]`
                Gold relation label of each dependent. Positions holding the
                padding label are ignored.

        Returns:
            A tuple containing
            - `FP[T, ""]` Loss $L_{arc} + L_{rel}$. $L_{rel}$ is added only when
              at least one dependent's head used for labeling equals its gold head.
            - Named outputs as a dict
                - logits_arc: `FP[T, "batch dependent=token head=token"]`
                  Logits of the arc prediction.
                - logits_rel: `FP[T, "batch dependent=token num_edge_classes"]`
                  Logits of the rel prediction.
                - arcs: `Int[T, "batch dependent=token"]` Predicted arcs.

        Raises:
            ValueError: If `heads` or `deprels` is None. Both are required,
                because `heads` also locates the word-first tokens.
        """
        if heads is None or deprels is None:
            raise ValueError(
                "`heads` and `deprels` are both required: `heads` locates the "
                "word-first tokens and both feed the losses"
            )
        BATCH, TOKEN = input_ids.shape
        # Call the module (not .forward) so registered hooks run.
        backbone_out: FP[T, "batch token self._dim_backbone"] = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        loss, logits_arc, arcs = self._forward_arc(backbone_out, heads)

        # Head used for each dependent when scoring labels: the gold head in
        # training mode (teacher forcing), the decoded arc otherwise.
        pred_heads: Int[T, "batch dependent=token"]
        if self.training:
            pred_heads = heads.clone()
            # gather() below indexes with these values; padding positions get
            # head 0 (the padding label is presumably not a valid token index).
            pred_heads[heads == self._padding_label] = 0
        else:
            pred_heads = arcs

        h_rel_dep: FP[T, "batch dependent dim_h_rel"] = self.mlp_h_rel_dep(backbone_out)
        h_rel_head_all: FP[T, "batch head dim_h_rel"] = self.mlp_h_rel_head(backbone_out)
        # For each dependent, select the representation of its head token.
        h_rel_head: FP[T, "batch dependent dim_h_rel"] = torch.gather(
            h_rel_head_all,
            1,
            pred_heads.reshape(BATCH, TOKEN, 1).expand_as(h_rel_head_all),
        )

        # Biaffine label score: h_dep^T W_c h_head for every class c, plus a
        # linear term over the concatenation [h_dep; h_head].
        term_a: FP[T, "batch dependent num_edge_classes"] = torch.einsum(
            "bid,cde,bie->bic", h_rel_dep, self.mat_w, h_rel_head
        )
        term_b: FP[T, "batch dependent num_edge_classes"] = self.linear_u345(
            torch.cat([h_rel_dep, h_rel_head], -1)
        )
        logits_rel: FP[T, "batch dependent num_edge_classes"] = term_a + term_b

        # Train labels only where the head used above matches the gold head.
        deprels = deprels.clone()
        deprels[pred_heads != heads] = self._padding_label
        if torch.any(deprels != self._padding_label):
            loss = loss + torch.nn.functional.cross_entropy(
                logits_rel.reshape(BATCH * TOKEN, -1),
                deprels.reshape(BATCH * TOKEN),
                ignore_index=self._padding_label,
            )

        return loss, {"logits_arc": logits_arc, "logits_rel": logits_rel, "arcs": arcs}
|