Skip to content

Commit

Permalink
whisper : add mechanism for aborting the whisper_full() computation
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 27, 2022
1 parent 6fd5358 commit 4698dcd
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
13 changes: 13 additions & 0 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,19 @@ int main(int argc, char ** argv) {
wparams.new_segment_callback_user_data = &user_data;
}

// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race

wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}

if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "%s: failed to process audio\n", argv[0]);
return 10;
Expand Down
13 changes: 13 additions & 0 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,

/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
Expand Down Expand Up @@ -2497,6 +2500,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.new_segment_callback =*/ nullptr,
/*.new_segment_callback_user_data =*/ nullptr,

/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,
};
} break;
}
Expand Down Expand Up @@ -2659,6 +2665,13 @@ int whisper_full(
break;
}

if (params.encoder_begin_callback) {
if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) {
fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__);
break;
}
}

// encode audio features starting at offset seek
if (whisper_encode(ctx, seek, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to encode\n", __func__);
Expand Down
11 changes: 11 additions & 0 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ extern "C" {
// Use the whisper_full_...() functions to obtain the text segments
typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);

// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);

// Parameters for the whisper_full() function
// If you chnage the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
struct whisper_full_params {
enum whisper_sampling_strategy strategy;

Expand Down Expand Up @@ -231,6 +239,9 @@ extern "C" {

whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;

whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};

WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
Expand Down

0 comments on commit 4698dcd

Please sign in to comment.