Commit 9f2d2d6

Reworked w/ pre-allocated matrices, verrrrrrrry slow

1 parent 0c3ccc6 commit 9f2d2d6
2 files changed: +30 −67 lines


test/torchtext_unittest/prototype/test_generate.py (+1 −1)

@@ -90,7 +90,7 @@ def test_hf_DELETE(self) -> None:
             test_sequence_tk,
             max_len=100,
             pad_idx=t5.config.pad_token_id,
-            num_beams=10,
+            num_beams=7,
             beam_size_token=t5.config.vocab_size,
         )
         end = time.time() - start

torchtext/prototype/generate.py (+29 −66)
@@ -223,6 +223,12 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]

+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate everything
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
            # `emissions` and `N` are unused in this current implementation
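Note: this hunk is the core of the rework. Instead of rebuilding token and beam-index tensors on every decoding step, two buffers are allocated once before decoding and reused in place. A minimal runnable sketch of the pattern (the values of `num_sequences`, `num_beams`, `eos_idx`, and `device` are illustrative stand-ins for what `beam_search` receives):

```python
import torch

# Illustrative values; in the diff these come from beam_search's inputs.
num_sequences, num_beams = 2, 7
eos_idx = 1
device = "cpu"

# Allocate once, up front: one current-token slot and one beam-index slot
# per (sequence, beam) pair, reused in place on every decoding step.
token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)

assert token_idxs.shape == beam_idxs.shape == (num_sequences, num_beams, 1)
```

Seeding with `torch.full(..., eos_idx)` means any beam slot that is never written on a step falls through as EOS rather than as a garbage token.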

@@ -231,16 +237,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]

             encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
             out_probs, model_states = [], []

             start = 0
@@ -256,66 +254,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if end > curr_beam_size:
                     end = curr_beam_size

-                num_samples = end - start
-
                 if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                 else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+

                 # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                 )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                 if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                    beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                    # We could store this in model_kwargs
-                    num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                    num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                    if num_finished_hyps_in_step > 0:
-                        beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                    beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                    if num_finished_hyps_in_step > 0:
-                        sliced_cache = ()
-                        for states in reordered_cached:
-                            sliced_state = ()
-                            for state in states:
-                                sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                            sliced_cache = sliced_cache + (sliced_state,)
-                        reordered_cached = sliced_cache
+                    beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                    beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                    curr_beam_idxs = beam_idxs[i, :, 0]

+                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                     model_inputs["past_key_values"] = reordered_cached

                 # Forward pass
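Note: each chunk now writes the surviving hypotheses' tokens into the pre-allocated buffer and hands the decoder a fixed `(num_beams, 1)` view, so input shapes stay constant even when fewer than `num_beams` hypotheses survive a step. A small self-contained sketch of that in-place update (all values illustrative):

```python
import torch

num_beams, eos_idx, device = 7, 1, "cpu"
i = 0  # index of the sequence currently being decoded
token_idxs = torch.full((1, num_beams, 1), eos_idx, dtype=torch.long, device=device)

# Suppose only 4 of the 7 beams produced a token this step (illustrative).
prev_step_token_idxs = [12, 45, 3, 99]

# Write the surviving tokens into the pre-allocated buffer; untouched slots
# keep eos_idx from the last reset.
token_indices = torch.tensor(prev_step_token_idxs, dtype=torch.long, device=device)
token_idxs[i, : len(token_indices), 0] = token_indices

# The decoder always sees a fixed (num_beams, 1) input regardless of how
# many hypotheses actually survived.
curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
print(curr_token_idxs.squeeze(1))  # tensor([12, 45,  3, 99,  1,  1,  1])
```

The same trick replaces the old pad/clamp/slice dance around `_reorder_cache`: the fixed-size `beam_idxs` buffer already has zeros in the unused slots, so the cache reorder takes the buffer view directly.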
@@ -329,18 +293,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if self.is_huggingface_model:
                     self._update_model_kwargs_for_generation(outputs, model_kwargs)

+                # Reset
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                     sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
                     # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
                             Seq2SeqModelState(
                                 timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
+                                sequence=[],
+                                lm_scores=0,
                             )
                         )
                     )
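Note: because the two buffers are shared across sequences and steps, they are wiped in place (back to `eos_idx` / `0`) once a step's scores have been collected, rather than being re-allocated. Sketched standalone, with illustrative values:

```python
import torch

num_sequences, num_beams, eos_idx = 2, 7, 1
token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx, dtype=torch.long)
beam_idxs = torch.zeros((num_sequences, num_beams, 1), dtype=torch.long)

i = 0
token_idxs[i, :3, 0] = torch.tensor([12, 45, 3])  # stale values from the step

# In-place reset, as in the diff, so the next step of sequence `i` starts
# from a clean slate without allocating new tensors.
token_idxs[i, :, 0] = eos_idx
beam_idxs[i, :, 0] = 0
assert bool((token_idxs[i, :, 0] == eos_idx).all())
```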
@@ -386,10 +353,6 @@ def is_not_neg_one(elem: int) -> bool:
            if not self.is_encoder_decoder:
                final_tokens = input_ids[timestep].tolist() + final_tokens

-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
            # Convert from list to tensors
            final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
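Note: the removed "makeshift padding" loop right-padded each finished hypothesis with zeros so that variable-length results could be stacked into a single tensor. A hypothetical standalone equivalent of the deleted logic, for reference (`pad_and_stack` is not part of the codebase):

```python
import torch

# Hypothetical helper mirroring the removed loop: right-pad each hypothesis
# token list with 0 up to max_len so the results can be stacked.
def pad_and_stack(hypotheses, max_len, pad_value=0):
    padded = [h + [pad_value] * (max_len - len(h)) for h in hypotheses]
    return torch.tensor(padded, dtype=torch.long)

# Two finished beams of different lengths, padded to a common length of 5.
print(pad_and_stack([[2, 5, 7], [2, 9]], max_len=5))
# tensor([[2, 5, 7, 0, 0],
#         [2, 9, 0, 0, 0]])
```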
