@@ -223,6 +223,12 @@ def beam_search(
         encoder_output_key = "last_hidden_state" if self.is_huggingface_model else "encoder_output"
         encoder_output = model_kwargs["encoder_outputs"][encoder_output_key]

+        num_sequences = input_ids.shape[0]
+
+        # Pre-allocate everything
+        token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx).to(dtype=torch.long, device=device)
+        beam_idxs = torch.zeros((num_sequences, num_beams, 1)).to(dtype=torch.long, device=device)
+
         def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_step_model_states, timestep):
             # `emissions` and `N` are unused in this current implementation
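
Note: the two buffers above are allocated once per `beam_search` call and then reused by every invocation of `update_func`, which runs at each decoding step. A minimal standalone sketch of the pattern (sizes and the `surviving` tensor below are hypothetical, for illustration only, not code from this commit):

    import torch

    num_sequences, num_beams, eos_idx = 2, 5, 1  # hypothetical sizes
    device = torch.device("cpu")

    # Allocate once, outside the per-step callback
    token_idxs = torch.full((num_sequences, num_beams, 1), eos_idx, dtype=torch.long, device=device)

    # Per step: overwrite a slice in place instead of materializing a new tensor
    surviving = torch.tensor([4, 7, 9], dtype=torch.long)  # say 3 of 5 beams survived
    token_idxs[0, : len(surviving), 0] = surviving

    # After the step: restore the fill value so stale entries cannot leak
    token_idxs[0, :, 0] = eos_idx

Filling with `eos_idx` rather than zeros plausibly ensures that any beam slot not overwritten in a given step reads as end-of-sequence instead of as a real token.
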
@@ -231,16 +237,8 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
             # For first timestep, create previous step token_idxs and model_states
             if timestep == 0:
                 prev_step_token_idxs = [-1]
-                prev_step_model_states = [
-                    create_emitting_model_state(
-                        Seq2SeqModelState(timestep=0, sequence=input_ids[i].unsqueeze(0), lm_scores=None)
-                    )
-                ]

             encoder_output_for_curr_seq = encoder_output[i, :, :].unsqueeze(0) if self.is_encoder_decoder else None
-            prev_model_state_sequences = [
-                get_obj_from_emitting_model_state(state).sequence for state in prev_step_model_states
-            ]
             out_probs, model_states = [], []

             start = 0
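
Note: the deleted bookkeeping rebuilt each hypothesis's full prefix every step (`get_obj_from_emitting_model_state(state).sequence`, then `torch.cat` in the next hunk), copying the whole prefix per step. Once the decoder's `past` cache carries the history, only the newest token per beam has to be materialized. A rough sketch of the cost difference, with hypothetical sizes (not code from this commit):

    import torch

    T, num_beams = 100, 5

    # Removed approach: the tracked sequence grows by one token per step, and
    # each torch.cat copies the entire prefix -- O(T^2) total work over T steps
    seq = torch.zeros((num_beams, 1), dtype=torch.long)
    for _ in range(T):
        new_tokens = torch.randint(0, 1000, (num_beams, 1))
        seq = torch.cat([seq, new_tokens], dim=-1)

    # New approach: only the latest token per beam is materialized each step,
    # since past_key_values already encodes the prefix -- O(T) total
    latest = torch.randint(0, 1000, (num_beams, 1))
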
@@ -256,66 +254,32 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if end > curr_beam_size:
                     end = curr_beam_size

-                num_samples = end - start
-
                 if prev_step_token_idxs != [-1]:
-                    state_sequences = torch.cat(prev_model_state_sequences[start:end], dim=0)
-                    token_indices = (
-                        torch.Tensor(prev_step_token_idxs[start:end])
-                        .to(dtype=torch.long, device=device)
-                        .reshape(num_samples, 1)
-                    )
-
-                    state_and_tokens = torch.cat(
-                        [state_sequences, token_indices], dim=-1
-                    )  # [batch_size x (timestep + 1)]
-                    assert state_and_tokens.shape == (
-                        num_samples,
-                        timestep + 1,
-                    ), f"state_and_tokens has shape {state_and_tokens.shape} = expected {(num_samples, timestep + 1)}"
+                    token_indices = torch.Tensor(prev_step_token_idxs[start:end]).to(dtype=torch.long, device=device)
+                    token_idxs[i, : len(token_indices), 0] = token_indices
+                    curr_token_idxs = token_idxs[i, :, 0].reshape(num_beams, 1)
                 else:
-                    assert len(prev_model_state_sequences) == 1
-                    state_and_tokens = token_indices = prev_model_state_sequences[0].expand(
-                        num_beams, -1
-                    )  # TODO: Make this more robust
-
-                # Cleanup -- combine this with the above
-                if self.is_encoder_decoder:
-                    # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
-                    # This is a view-only operation and doesn't copy
-                    model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
-                        num_samples if timestep > 0 else num_beams, -1, -1
-                    )
+                    if self.is_encoder_decoder:
+                        # Expand encoder outputs along the batch dimension so that they match the decoder input state's batch size
+                        # This is a view-only operation and doesn't copy
+                        model_kwargs["encoder_outputs"][encoder_output_key] = encoder_output_for_curr_seq.expand(
+                            num_beams, -1, -1
+                        )
+                    curr_token_idxs = torch.zeros((num_beams, 1)).to(dtype=torch.long, device=device)
+

                 # Preprocess inputs for generation
                 model_inputs = self.model.prepare_inputs_for_generation(
-                    token_indices, **model_kwargs
+                    curr_token_idxs, **model_kwargs
                 )  # This should technically work with state_and_tokens, but the prepare function has to splice if past (like HF does)
                 if self.is_huggingface_model:
                     model_inputs.update(self._huggingface_model_input_values)
                 if len(prev_step_hyp_idxs) > 1 and model_kwargs["past"] is not None:
-                    beam_idxs = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
-
-                    # We could store this in model_kwargs
-                    num_hyps_in_prev_step = model_kwargs["past"][0][0].shape[0]
-
-                    num_finished_hyps_in_step = num_hyps_in_prev_step - len(prev_step_hyp_idxs)
-                    if num_finished_hyps_in_step > 0:
-                        beam_idxs = F.pad(beam_idxs, (0, num_finished_hyps_in_step), "constant", 0)
-
-                    beam_idxs = torch.clamp(beam_idxs, max=len(prev_step_hyp_idxs) - 1)
-
-                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], beam_idxs)
-
-                    if num_finished_hyps_in_step > 0:
-                        sliced_cache = ()
-                        for states in reordered_cached:
-                            sliced_state = ()
-                            for state in states:
-                                sliced_state = sliced_state + (state[: len(prev_step_hyp_idxs)],)
-                            sliced_cache = sliced_cache + (sliced_state,)
-                        reordered_cached = sliced_cache
+                    beam_indices = torch.Tensor(prev_step_hyp_idxs).to(dtype=torch.int32)
+                    beam_idxs[i, : len(prev_step_hyp_idxs), 0] = beam_indices
+                    curr_beam_idxs = beam_idxs[i, :, 0]

+                    reordered_cached = self.model._reorder_cache(model_kwargs["past"], curr_beam_idxs)
                     model_inputs["past_key_values"] = reordered_cached

                 # Forward pass
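
Note: two details in this hunk are easy to miss. First, `Tensor.expand` returns a broadcast view (stride 0 along the expanded dimension), so replicating the encoder output across `num_beams` costs no extra memory, unlike `repeat`; this is what the "view-only operation" comment refers to. Second, because `beam_idxs` is now pre-padded to a fixed `num_beams` width, the removed pad/clamp/slice dance around the cache reorder is no longer needed. A sketch of both ideas (the `reorder_cache` helper is hypothetical, shaped like the `_reorder_cache` method found on many Hugging Face models, not code from this commit):

    import torch

    # expand vs. repeat: a view versus a real copy
    enc = torch.randn(1, 12, 16)      # [1 x src_len x hidden], hypothetical sizes
    view = enc.expand(5, -1, -1)      # no allocation; shares enc's storage
    copy = enc.repeat(5, 1, 1)        # allocates 5x the memory
    assert view.data_ptr() == enc.data_ptr()

    # Reordering a tuple-of-tuples KV cache by beam index: pick, for every
    # layer's key/value tensor, the rows of the beams that survived this step
    def reorder_cache(past, beam_idxs):
        beam_idxs = beam_idxs.to(torch.long)  # index_select wants integer indices
        return tuple(
            tuple(state.index_select(0, beam_idxs) for state in layer)
            for layer in past
        )
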
@@ -329,18 +293,21 @@ def update_func(emissions, N, T, prev_step_token_idxs, prev_step_hyp_idxs, prev_
                 if self.is_huggingface_model:
                     self._update_model_kwargs_for_generation(outputs, model_kwargs)

+                # Reset
+                token_idxs[i, :, 0] = eos_idx
+                beam_idxs[i, :, 0] = 0
+
                 # Keep track of probabilities over vocab for this pairing
-                # TODO: fix how we track the number here?
-                for i in range(lm_scores.shape[0]):
+                for i in range(num_beams):
                     sample_lm_scores = lm_scores[i, -1]
                     out_probs.append(sample_lm_scores.tolist())
                     # Keep track of sequence and decoder hidden states
                     model_states.append(
                         create_emitting_model_state(
                             Seq2SeqModelState(
                                 timestep=timestep,
-                                sequence=state_and_tokens[i].unsqueeze(0),
-                                lm_scores=sample_lm_scores,
+                                sequence=[],
+                                lm_scores=0,
                             )
                         )
                     )
@@ -386,10 +353,6 @@ def is_not_neg_one(elem: int) -> bool:
             if not self.is_encoder_decoder:
                 final_tokens = input_ids[timestep].tolist() + final_tokens

-            # Makeshift padding so that we can stack the tensors
-            while len(final_tokens) < max_len:
-                final_tokens += [0]
-
             # Convert from list to tensors
             final_tokens_as_tensors = torch.Tensor(final_tokens).to(torch.long)
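
Note: with the makeshift `while` padding gone, each hypothesis keeps its natural length. If downstream code still needs to stack ragged hypotheses into a single tensor, the idiomatic tool is `torch.nn.utils.rnn.pad_sequence`; a sketch with hypothetical token lists (not code from this commit):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    hyps = [torch.tensor([5, 8, 2]), torch.tensor([5, 9, 4, 7, 2])]
    padded = pad_sequence(hyps, batch_first=True, padding_value=0)
    # padded.shape == (2, 5); the shorter hypothesis is right-padded with 0s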