PyTorch Implementation of Dependency Parsing

Prerequisites

Wikipedia: Dependency grammar

Assignment: Dependency parsing

In a nutshell, you will proceed as follows:

  • From the XLM-RoBERTa embeddings of each token, extract representations H_head and H_dep using a one-layer MLP with some output dimension $d$ (see the D&M paper for suggestions on hyperparameters). Note that you need a separate MLP for the head representation and for the dependent representation.

  • Calculate a score for each pair of a potential head $i$ and a potential dependent $j$ by computing H_head[i].T * U1 * H_dep[j] + H_head[i].T * u2. U1 is a $d \times d$ matrix, and u2 is a $d$-dimensional vector; their entries are parameters of the model that are learned in training.
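The two steps above can be sketched in a few lines of PyTorch. This is only an illustration, not the implementation shown later: random tensors stand in for the XLM-RoBERTa embeddings, and the sizes `n`, `dim_backbone`, and `d`, as well as the ReLU activation, are assumptions chosen for the example.

```python
import torch

torch.manual_seed(0)
n, dim_backbone, d = 5, 16, 8  # hypothetical sizes

# Stand-in for the XLM-RoBERTa token embeddings of one sentence.
embeddings = torch.randn(n, dim_backbone)

# Separate one-layer MLPs for the head and the dependent representation.
mlp_head = torch.nn.Sequential(torch.nn.Linear(dim_backbone, d), torch.nn.ReLU())
mlp_dep = torch.nn.Sequential(torch.nn.Linear(dim_backbone, d), torch.nn.ReLU())
h_head = mlp_head(embeddings)  # (n, d)
h_dep = mlp_dep(embeddings)    # (n, d)

# Learnable parameters of the scorer.
U1 = torch.nn.Parameter(torch.randn(d, d))
u2 = torch.nn.Parameter(torch.randn(d))

# scores[i, j] = h_head[i]^T U1 h_dep[j] + h_head[i]^T u2
scores = h_head @ U1 @ h_dep.T + (h_head @ u2).unsqueeze(1)
print(scores.shape)  # torch.Size([5, 5])
```

The second term depends only on the head, so it is computed once per head and broadcast over all dependents.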

Head Prediction

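The original formula is:

$$S = H_{head}\, U_1\, H_{dep}^\top + H_{head}\, u_2\, \mathbf{1}_n^\top,$$

where $H_{head}, H_{dep} \in \mathbb{R}^{n \times d}$, $U_1 \in \mathbb{R}^{d \times d}$, $u_2 \in \mathbb{R}^{d}$, and $\mathbf{1}_n$ is the all-ones vector that broadcasts the second term over the dependents ($n$ is the number of words and $d$ is a hyperparameter). $S_{ij}$ is the score that $i$ is the head and $j$ is the dependent.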

With the assumption that every token has exactly one head (except the root, which has none), we would in theory apply softmax along dimension 0 of the logits $S$, the head dimension, to obtain the probability $P(\text{head} = i \mid \text{dependent} = j)$. However, this clashes with PyTorch conventions: CrossEntropyLoss treats dimension 1 of its input as the class dimension, which is the last dimension for a 2-D (N, C) input. To resolve this, I refactored the formula to compute the transpose, $S^\top = H_{dep}\, U_1^\top H_{head}^\top + \mathbf{1}_n u_2^\top H_{head}^\top$. Consequently, $S^\top_{ij}$ is the score of $i$ being the dependent and $j$ being the head, which lets us apply softmax along the last dimension.
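To make the dimension argument concrete, the following sketch checks that normalising the head-major matrix over dimension 0 gives the same probabilities as normalising its transpose over the last dimension, and that `cross_entropy` on the transposed layout matches the manual negative log-likelihood. The scores, the size `n`, and the gold heads are made up for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 4
S = torch.randn(n, n)  # S[i, j]: token i is the head, token j the dependent
gold_heads = torch.tensor([0, 0, 1, 1])  # hypothetical gold head of each dependent

# Head-major layout: normalise over the head dimension, i.e. dim 0.
probs_head_major = torch.softmax(S, dim=0)

# Dependent-major layout: the same distribution now lives on the last dim.
S_t = S.T  # S_t[i, j]: token i is the dependent, token j the head
probs_dep_major = torch.softmax(S_t, dim=-1)
assert torch.allclose(probs_head_major.T, probs_dep_major)

# cross_entropy treats dim 1 of an (N, C) input as the class dimension,
# so the dependent-major logits can be passed in directly.
loss = F.cross_entropy(S_t, gold_heads)
manual = -torch.log(probs_dep_major[torch.arange(n), gold_heads]).mean()
assert torch.allclose(loss, manual, atol=1e-5)
```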

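Consider the term $H_{dep}\, U_1^\top + \mathbf{1}_n u_2^\top$, where $H_{dep} \in \mathbb{R}^{n \times d}$, $U_1 \in \mathbb{R}^{d \times d}$, and $u_2 \in \mathbb{R}^{d}$. Recall that the standard definition of a neural network linear layer is $y = x W^\top + b$, where $x \in \mathbb{R}^{n \times d_{in}}$, $W \in \mathbb{R}^{d_{out} \times d_{in}}$, and $b \in \mathbb{R}^{d_{out}}$ is added to every row. By comparing the two, it becomes evident that this term is mathematically equivalent to a standard linear layer. Here, $U_1$ serves as the weight matrix, $u_2$ acts as the bias, and $H_{dep}$ is the input. Thus, the equation simplifies to

$$S^\top = \mathrm{Linear}(H_{dep})\, H_{head}^\top.$$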

Formulating this operation as a native PyTorch Linear layer not only results in cleaner code, but also lets PyTorch run it through its optimized matrix-multiplication kernels, which can improve GPU utilization (for example via Tensor Cores on hardware that supports them).
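The equivalence can be checked numerically. In the sketch below, the weight and bias of a `torch.nn.Linear` layer play the roles of U1 and u2; the sizes `n` and `d` and the random features are assumptions for illustration only.

```python
import torch

torch.manual_seed(0)
n, d = 5, 8  # hypothetical sentence length and hidden size
h_head = torch.randn(n, d)
h_dep = torch.randn(n, d)

# The Linear layer's weight plays the role of U1 and its bias the role of u2.
linear_u12 = torch.nn.Linear(d, d)
U1, u2 = linear_u12.weight, linear_u12.bias

# Explicit head-major formula: S[i, j] = h_head[i]^T U1 h_dep[j] + h_head[i]^T u2
S = h_head @ U1 @ h_dep.T + (h_head @ u2).unsqueeze(1)

# Linear-layer form of the transpose: rows index dependents, columns index heads.
S_t = linear_u12(h_dep) @ h_head.T

print(torch.allclose(S.T, S_t, atol=1e-4))  # True
```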

Edge Label Prediction

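The original formula is:

$$L_{ij} = H_{dep}[i]^\top W\, H_{head}[j] + U_3\, H_{dep}[i] + U_4\, H_{head}[j] + u_5,$$

in which $L \in \mathbb{R}^{n \times n \times c}$, $H_{dep}, H_{head} \in \mathbb{R}^{n \times d}$, $W \in \mathbb{R}^{c \times d \times d}$, $U_3, U_4 \in \mathbb{R}^{c \times d}$, $u_5 \in \mathbb{R}^{c}$, and $c$ is the number of edge classes. The bilinear term is evaluated per class, so its $k$-th entry is $H_{dep}[i]^\top W_k\, H_{head}[j]$, and $L_{ij}$ holds the label scores for dependent $i$ with head $j$.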

For a given dependent $i$ with head $y_i$, we only need the score $L_{i y_i}$. Because $i$ and $y_i$ are coupled during this phase, we can eliminate the head dimension entirely. We achieve this by gathering the features of the predicted heads, effectively transforming the matrix $H_{head}$ into a tensor of batched vectors $\tilde{H}_{head}$, where $\tilde{H}_{head}[i] = H_{head}[y_i]$. Consequently, the number of (dependent, head) pairs that must be scored drops from $O(n^2)$ to $O(n)$.
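The gathering step can be illustrated with `torch.gather`. The sizes and the head indices below are made up; the point is only that, after the gather, row `i` holds the feature vector of the head chosen for dependent `i`.

```python
import torch

torch.manual_seed(0)
batch, n, d = 2, 4, 3  # hypothetical sizes
h_rel_head = torch.randn(batch, n, d)
# heads[b, i] is the index of the head chosen for dependent i (made-up values).
heads = torch.tensor([[0, 0, 1, 1], [0, 2, 0, 2]])

# Broadcast the head indices over the feature dimension, then gather along the token dimension.
index = heads.unsqueeze(-1).expand(-1, -1, d)  # (batch, n, d)
gathered = torch.gather(h_rel_head, 1, index)  # (batch, n, d)

# Row i of `gathered` is the feature vector of the head of dependent i.
assert torch.equal(gathered[1, 1], h_rel_head[1, 2])

# Equivalent advanced-indexing form.
same = h_rel_head[torch.arange(batch).unsqueeze(1), heads]
assert torch.equal(gathered, same)
```

The result has one head vector per dependent, so the later label scoring never materialises an n × n grid of pairs.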

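Let us examine the term $H_{dep}\, U_3^\top + \tilde{H}_{head}\, U_4^\top + \mathbf{1}_n u_5^\top$, where $\tilde{H}_{head}$ holds the gathered head features (row $i$ is the feature vector of the head of dependent $i$). By concatenating $U_3$ and $U_4$, we form a new weight matrix $U_{34} = [U_3, U_4] \in \mathbb{R}^{c \times 2d}$. Similarly, by concatenating $H_{dep}$ and $\tilde{H}_{head}$, we construct a joint feature matrix $H_{cat} = [H_{dep}, \tilde{H}_{head}] \in \mathbb{R}^{n \times 2d}$. The expression is then equivalent to $H_{cat}\, U_{34}^\top + \mathbf{1}_n u_5^\top$. Once again, this structure perfectly aligns with a standard linear layer. Therefore, the final streamlined equation is:

$$L_i = H_{dep}[i]^\top W\, \tilde{H}_{head}[i] + \mathrm{Linear}\big([H_{dep}, \tilde{H}_{head}]\big)[i].$$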
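As a sanity check, the sketch below compares the separate-matrix form with a single linear layer over concatenated features, then adds the per-class bilinear term. The names `U3`, `U4`, and `u5` follow the comments in the code listing; the sizes and random tensors are assumptions, and `h_head_gathered` is a stand-in for the gathered head features.

```python
import torch

torch.manual_seed(0)
n, d, c = 5, 6, 4  # hypothetical: tokens, hidden size, edge classes
h_dep = torch.randn(n, d)
h_head_gathered = torch.randn(n, d)  # stand-in for the gathered head features

linear_u345 = torch.nn.Linear(2 * d, c)
# Split the joint weight into the block applied to dependents and the block applied to heads.
U3, U4 = linear_u345.weight[:, :d], linear_u345.weight[:, d:]
u5 = linear_u345.bias

separate = h_dep @ U3.T + h_head_gathered @ U4.T + u5
joint = linear_u345(torch.cat([h_dep, h_head_gathered], dim=-1))
print(torch.allclose(separate, joint, atol=1e-4))  # True

# Adding the per-class bilinear term gives the full label logits.
W = torch.randn(c, d, d)
logits_rel = torch.einsum("id,cde,ie->ic", h_dep, W, h_head_gathered) + joint
print(logits_rel.shape)  # torch.Size([5, 4])
```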

Coding

# Imports used by this listing. FP/Bool/Int/T (tensor type annotations),
# chu_liu_edmonds, Mlp and RobertaDependencyParserBase are defined elsewhere
# in the project and are not shown here.
import numpy as np
import torch
from transformers import XLMRobertaModel


@torch.no_grad()
def word_cle(
    logits_arc: FP[T, "batch dependent=token head=token"],
    word_first_token_mask: Bool[T, "batch dependent=token"],
) -> Int[T, "batch dependent=token"]:
    """Calculate the MST of words using the Chu-Liu-Edmonds algorithm.

    Args:
        logits_arc: `FP[T, "batch dependent=token head=token"]` Logits of the arc prediction.
        word_first_token_mask: `Bool[T, "batch token"]` Mask marking the first token of each word.

    Returns:
        `Int[T, "batch dependent=token"]` Predicted head of each word, stored at the word's first token; all other positions are set to 0.
    """
    assert not torch.any(word_first_token_mask[:, 0]), "index 0 must be the root"
    results = torch.zeros_like(word_first_token_mask, device="cpu", dtype=torch.long)
    # Work on a CPU copy so the caller's mask is left untouched.
    word_first_token_mask = word_first_token_mask.to(device="cpu", copy=True).numpy()
    word_first_token_mask[:, 0] = True  # include the root for CLE
    for logits, mask, res in zip(
        logits_arc.to(device="cpu", dtype=torch.float64).numpy(),
        word_first_token_mask,
        results,
    ):
        word_logits = logits[mask][:, mask]  # T[1+word, 1+word]
        # log-softmax is not needed here. Check the Calculation Note in README
        word_preds, _ = chu_liu_edmonds(word_logits)  # T[1+word]
        # Map words to their corresponding first tokens.
        wordid_to_first_tokenid = np.where(mask)[0]
        # Entry 0 belongs to the root, which has no head, so skip it on both sides.
        res[torch.from_numpy(wordid_to_first_tokenid[1:])] = torch.from_numpy(
            wordid_to_first_tokenid[word_preds[1:]]
        )
    return results

class RobertaDependencyParser(RobertaDependencyParserBase):
    def __init__(
        self,
        backbone: XLMRobertaModel,
        padding_label: int,
        training_backbone: bool,
        dim_h_arc: int,
        dim_h_rel: int,
        num_edge_classes: int,
    ):
        super().__init__(backbone, padding_label, training_backbone)
        # Separate MLPs for the head and the dependent representation (arc scoring).
        self.mlp_h_arc_head = Mlp([self._dim_backbone, dim_h_arc])
        self.mlp_h_arc_dep = Mlp([self._dim_backbone, dim_h_arc])
        # weight as U1, bias as u2, check the Calculation Notes in README.
        self.linear_u12 = torch.nn.Linear(dim_h_arc, dim_h_arc)

        # Separate MLPs for the head and the dependent representation (label scoring).
        self.mlp_h_rel_head = Mlp([self._dim_backbone, dim_h_rel])
        self.mlp_h_rel_dep = Mlp([self._dim_backbone, dim_h_rel])
        # weight as [U3,U4], bias as u5, check the Calculation Notes in README.
        self.linear_u345 = torch.nn.Linear(dim_h_rel * 2, num_edge_classes)
        # One bilinear matrix per edge class.
        self.mat_w = torch.nn.Parameter(
            torch.randn(num_edge_classes, dim_h_rel, dim_h_rel)
        )

    def _forward_arc(
        self,
        backbone_out: FP[T, "batch token _dim_backbone"],
        heads: Int[T, "batch dependent=token"],
    ) -> tuple[
        FP[T, ""],
        FP[T, "batch dependent=token head=token"],
        Int[T, "batch dependent=token"],
    ]:
        """
        Returns:
            A tuple containing
            - `FP[T, ""]` Cross entropy loss of arc prediction.
            - `FP[T, "batch dependent=token head=token"]` Logits of the arc prediction.
            - `Int[T, "batch dependent=token"]` Predicted arcs.
        """
        BATCH, TOKEN, _ = backbone_out.shape

        # MLP on the last dim
        h_head: FP[T, "batch head dim_h_arc"] = self.mlp_h_arc_head(backbone_out)
        h_dep: FP[T, "batch dependent dim_h_arc"] = self.mlp_h_arc_dep(backbone_out)

        term: FP[T, "batch dependent dim_h_arc"] = self.linear_u12(h_dep)
        logits: FP[T, "batch dependent head"] = torch.einsum(
            "bid,bjd->bij", term, h_head
        )
        arcs = word_cle(logits, heads != self._padding_label).to(heads.device)

        # calculate cross entropy loss
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(BATCH * TOKEN, TOKEN),
            heads.reshape(BATCH * TOKEN),
            ignore_index=self._padding_label,
        )
        return loss, logits, arcs

    def forward(
        self,
        input_ids: Int[T, "batch token"],
        attention_mask: Int[T, "batch token"],
        heads: Int[T, "batch dependent=token"],
        deprels: Int[T, "batch dependent=token"],
        **_,
    ) -> tuple[FP[T, ""], dict[str, T]]:
        r"""Dependency Parsing based on RoBERTa, with edge label prediction.

        Args:
            input_ids: `Int[T, "batch token"]` Batched token indices.
            attention_mask: `Int[T, "batch token"]` Attention mask passed to the backbone.
            heads: `Int[T, "batch dependent=token"]` Ground-truth heads used to calculate the loss.
            deprels: `Int[T, "batch dependent=token"]` Ground-truth edge labels used to calculate the loss.

        Returns:
            A tuple containing
            - `FP[T, ""]` Loss $L_{arc} + L_{rel}$.
            - Named outputs as a dict
                - logits_arc: `FP[T, "batch dependent=token head=token"]` Logits of the arc prediction.
                - logits_rel: `FP[T, "batch dependent=token num_edge_classes"]` Logits of the rel prediction.
                - arcs: `Int[T, "batch dependent=token"]` Predicted arcs.
        """
        # heads and deprels must be both provided or both not provided
        assert not ((heads is None) ^ (deprels is None))
        BATCH, TOKEN = input_ids.shape
        backbone_out: FP[T, "batch token self._dim_backbone"] = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state

        # predict arcs and calculate L_{arc} first
        loss, logits_arc, arcs = self._forward_arc(backbone_out, heads)
        pred_heads: Int[T, "batch dependent"] = arcs
        if self.training:  # use ground truth in training.
            pred_heads = heads.clone()
            pred_heads[heads == self._padding_label] = 0

        # MLP on the last dim, get rel features
        h_rel_dep: FP[T, "batch dependent dim_h_rel"] = self.mlp_h_rel_dep(backbone_out)
        h_rel_head: FP[T, "batch head dim_h_rel"] = self.mlp_h_rel_head(backbone_out)
        # only select feature vectors from the predicted head.
        h_rel_head: FP[T, "batch dependent dim_h_rel"] = torch.gather(
            h_rel_head, 1, pred_heads.reshape(BATCH, TOKEN, 1).expand_as(h_rel_head)
        )

        # bilinear term: one dim_h_rel x dim_h_rel matrix per edge class
        term_a: FP[T, "batch dependent num_edge_classes"] = torch.einsum(
            "bid,cde,bie->bic", h_rel_dep, self.mat_w, h_rel_head
        )
        # linear term over the concatenated dependent and head features
        term_b: FP[T, "batch dependent _"] = torch.cat([h_rel_dep, h_rel_head], -1)
        term_b: FP[T, "batch dependent num_edge_classes"] = self.linear_u345(term_b)
        logits_rel: FP[T, "batch dependent num_edge_classes"] = term_a + term_b

        # calculate cross entropy loss: L = L_{arc} + L_{rel}
        # ignore edges whose head is wrong (this also covers padded positions)
        deprels = deprels.clone()
        deprels[pred_heads != heads] = self._padding_label
        if torch.any(deprels != self._padding_label):
            loss = loss + torch.nn.functional.cross_entropy(
                logits_rel.reshape(BATCH * TOKEN, -1),
                deprels.reshape(BATCH * TOKEN),
                ignore_index=self._padding_label,
            )

        return loss, {"logits_arc": logits_arc, "logits_rel": logits_rel, "arcs": arcs}
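
The index bookkeeping in `word_cle` can be tried out in isolation. The sketch below uses made-up logits and a made-up mask, and replaces the Chu-Liu-Edmonds call with a plain per-row argmax purely as a stand-in (it does not guarantee a tree). Only the sub-matrix extraction and the mapping from word ids back to first-token ids mirror the function above; the `-1` head for the root is an assumption of this example.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens = 6
logits = rng.normal(size=(num_tokens, num_tokens))  # [dependent, head], made-up values

# Token 0 is the root; tokens 1, 3 and 4 start a word, tokens 2 and 5 continue one.
mask = np.array([True, True, False, True, True, False])

# Keep only the rows and columns of the root and the first tokens: (1 + words, 1 + words).
word_logits = logits[mask][:, mask]

# Stand-in for chu_liu_edmonds: pick the best head per word independently.
# Position 0 is the root, which gets no head (-1 here, by assumption).
word_preds = word_logits.argmax(axis=1)
word_preds[0] = -1

# Map word ids back to the token ids of their first tokens.
wordid_to_first_tokenid = np.where(mask)[0]  # array([0, 1, 3, 4])
token_heads = np.zeros(num_tokens, dtype=np.int64)
token_heads[wordid_to_first_tokenid[1:]] = wordid_to_first_tokenid[word_preds[1:]]
print(token_heads)
```

Positions that are not the first token of a word (and the root itself) keep the value 0, matching the convention documented in `word_cle`.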