Skip to content

Commit

Permalink
Implement worker api for killing running actions
Browse files Browse the repository at this point in the history
Implements worker api for requesting a currently running action to be
killed. Allows for action cancellation requests to be sent from the
scheduler during scenarios such as client disconnection.

towards TraceMachina#338
  • Loading branch information
Zach Birenbaum committed Apr 8, 2024
1 parent 6b9e68e commit c55983d
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ message ConnectionResult {
reserved 2; // NextId.
}

/// Request to kill a running action sent from the scheduler to a worker.
message KillActionRequest {
/// The hex encoded unique qualifier for the action to be killed.
string action_id = 1;
reserved 2; // NextId.
}
/// Communication from the scheduler to the worker.
message UpdateForWorker {
oneof update {
Expand All @@ -152,8 +158,11 @@ message UpdateForWorker {
/// Informs the worker that it has been disconnected from the pool.
/// The worker may discard any outstanding work that is being executed.
google.protobuf.Empty disconnect = 4;

/// Instructs the worker to kill a specific running action.
KillActionRequest kill_action_request = 5;
}
reserved 5; // NextId.
reserved 6; // NextId.
}

message StartExecute {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,19 @@ pub struct ConnectionResult {
#[prost(string, tag = "1")]
pub worker_id: ::prost::alloc::string::String,
}
/// / Request to kill a running action sent from the scheduler to a worker.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct KillActionRequest {
/// / The hex encoded unique qualifier for the action to be killed.
#[prost(string, tag = "1")]
pub action_id: ::prost::alloc::string::String,
}
/// / Communication from the scheduler to the worker.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct UpdateForWorker {
#[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4")]
#[prost(oneof = "update_for_worker::Update", tags = "1, 2, 3, 4, 5")]
pub update: ::core::option::Option<update_for_worker::Update>,
}
/// Nested message and enum types in `UpdateForWorker`.
Expand All @@ -133,6 +141,9 @@ pub mod update_for_worker {
/// / The worker may discard any outstanding work that is being executed.
#[prost(message, tag = "4")]
Disconnect(()),
/// / Instructs the worker to kill a specific running action.
#[prost(message, tag = "5")]
KillActionRequest(super::KillActionRequest),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
Expand Down
1 change: 1 addition & 0 deletions nativelink-worker/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ rust_test_suite(
"//nativelink-util",
"@crates//:async-lock",
"@crates//:futures",
"@crates//:hex",
"@crates//:hyper",
"@crates//:once_cell",
"@crates//:pretty_assertions",
Expand Down
13 changes: 13 additions & 0 deletions nativelink-worker/src/local_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
Update::KeepAlive(()) => {
self.metrics.keep_alives_received.inc();
}
// The scheduler asked this worker to kill one specific action.
Update::KillActionRequest(kill_action_request) => {
// The id arrives as a hex string; decode it into a fixed 32-byte buffer.
// A string that is not valid hex, or that does not decode to exactly
// 32 bytes, is reported as an input error.
let mut action_id = [0u8; 32];
hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8])
.map_err(|e| make_input_err!(
"KillActionRequest failed to decode ActionId hex with error {}",
e
))?;

// NOTE(review): `?` propagates a failed kill (e.g. an action the manager
// cannot find) out of the enclosing function — confirm that this is the
// intended reaction rather than logging and continuing.
self.running_actions_manager
.kill_action(action_id)
.await
.err_tip(|| format!("Failed to send kill request for action {}", hex::encode(action_id)))?
}
Update::StartAction(start_execute) => {
self.metrics.start_actions_received.inc();
let add_future_channel = add_future_channel.clone();
Expand Down
16 changes: 16 additions & 0 deletions nativelink-worker/src/running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,8 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {

async fn kill_all(&self);

/// Requests that the single running action identified by `action_id` be killed.
///
/// Returns an error when the request cannot be carried out; the exact
/// failure conditions are up to the implementation.
async fn kill_action(&self, action_id: ActionId) -> Result<(), Error>;

fn metrics(&self) -> &Arc<Metrics>;
}

Expand Down Expand Up @@ -1783,6 +1785,20 @@ impl RunningActionsManager for RunningActionsManagerImpl {
.await
}

/// Kills the running action registered under `action_id`.
///
/// Returns an input error when `running_actions` has no entry for
/// `action_id`, or when the stored handle can no longer be upgraded.
async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> {
    // The lock guard is a temporary of this statement, so it is released
    // before the `.await` below.
    let maybe_running_action = self
        .running_actions
        .lock()
        .get(&action_id)
        .and_then(|entry| entry.upgrade());

    match maybe_running_action {
        Some(running_action) => {
            // NOTE(review): presumably resolves to an associated function that
            // takes the upgraded action, not this trait method — confirm.
            Self::kill_action(running_action).await;
            Ok(())
        }
        None => Err(make_input_err!(
            "Failed to get running action {}",
            hex::encode(action_id)
        )),
    }
}

// Note: When the future returns the process should be fully killed and cleaned up.
async fn kill_all(&self) {
self.metrics
Expand Down
91 changes: 91 additions & 0 deletions nativelink-worker/tests/local_worker_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ fn make_temp_path(data: &str) -> String {

#[cfg(test)]
mod local_worker_tests {
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::KillActionRequest;
use pretty_assertions::assert_eq;

use super::*; // Must be declared in every module.
Expand Down Expand Up @@ -638,4 +639,94 @@ mod local_worker_tests {

Ok(())
}

/// Verifies that a `KillActionRequest` sent over the update stream results in
/// `kill_action` being invoked on the actions manager with the action id that
/// was hex encoded into the request.
#[tokio::test]
async fn kill_action_request_kills_action() -> Result<(), Box<dyn std::error::Error>> {
const SALT: u64 = 1000;

let mut test_context = setup_local_worker(HashMap::new()).await;

let streaming_response = test_context.maybe_streaming_response.take().unwrap();

{
// Ensure our worker connects and properties were sent.
let props = test_context
.client
.expect_connect_worker(Ok(streaming_response))
.await;
assert_eq!(props, SupportedProperties::default());
}

// Complete registration by sending the ConnectionResult for this worker.
let mut tx_stream = test_context.maybe_tx_stream.take().unwrap();
{
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::ConnectionResult(ConnectionResult {
worker_id: "foobar".to_string(),
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// Action that will be started and then targeted by the kill request.
let action_digest = DigestInfo::new([3u8; 32], 10);
let action_info = ActionInfo {
command_digest: DigestInfo::new([1u8; 32], 10),
input_root_digest: DigestInfo::new([2u8; 32], 10),
timeout: Duration::from_secs(1),
platform_properties: PlatformProperties::default(),
priority: 0,
load_timestamp: SystemTime::UNIX_EPOCH,
insert_timestamp: SystemTime::UNIX_EPOCH,
unique_qualifier: ActionInfoHashKey {
instance_name: INSTANCE_NAME.to_string(),
digest: action_digest,
salt: SALT,
},
skip_cache_lookup: true,
digest_function: DigestHasherFunc::Blake3,
};

{
// Send execution request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::StartAction(StartExecute {
execute_request: Some(action_info.clone().into()),
salt: SALT,
queued_timestamp: None,
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}
let running_action = Arc::new(MockRunningAction::new());

// Send and wait for response from create_and_add_action to RunningActionsManager.
test_context
.actions_manager
.expect_create_and_add_action(Ok(running_action.clone()))
.await;

// The kill request carries the hash of the unique qualifier as a hex string.
let action_id = action_info.unique_qualifier.get_hash();
{
// Send kill request.
tx_stream
.send_data(encode_stream_proto(&UpdateForWorker {
update: Some(Update::KillActionRequest(KillActionRequest {
action_id: hex::encode(action_id),
})),
})?)
.await
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
}

// Wait until the mock manager reports which action it was asked to kill.
let killed_action_id = test_context.actions_manager.expect_kill_action().await;

// Make sure that the killed action is the one we intended.
assert_eq!(killed_action_id, action_id);

Ok(())
}
}
25 changes: 24 additions & 1 deletion nativelink-worker/tests/utils/mock_running_actions_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::
use nativelink_util::action_messages::ActionResult;
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;
use nativelink_worker::running_actions_manager::{Metrics, RunningAction, RunningActionsManager};
use nativelink_worker::running_actions_manager::{
ActionId, Metrics, RunningAction, RunningActionsManager,
};
use tokio::sync::mpsc;

#[derive(Debug)]
Expand All @@ -43,6 +45,9 @@ pub struct MockRunningActionsManager {

rx_kill_all: Mutex<mpsc::UnboundedReceiver<()>>,
tx_kill_all: mpsc::UnboundedSender<()>,

rx_kill_action: Mutex<mpsc::UnboundedReceiver<ActionId>>,
tx_kill_action: mpsc::UnboundedSender<ActionId>,
metrics: Arc<Metrics>,
}

Expand All @@ -57,13 +62,16 @@ impl MockRunningActionsManager {
let (tx_call, rx_call) = mpsc::unbounded_channel();
let (tx_resp, rx_resp) = mpsc::unbounded_channel();
let (tx_kill_all, rx_kill_all) = mpsc::unbounded_channel();
let (tx_kill_action, rx_kill_action) = mpsc::unbounded_channel();
Self {
rx_call: Mutex::new(rx_call),
tx_call,
rx_resp: Mutex::new(rx_resp),
tx_resp,
rx_kill_all: Mutex::new(rx_kill_all),
tx_kill_all,
rx_kill_action: Mutex::new(rx_kill_action),
tx_kill_action,
metrics: Arc::new(Metrics::default()),
}
}
Expand Down Expand Up @@ -108,6 +116,14 @@ impl MockRunningActionsManager {
.await
.expect("Could not receive msg in mpsc");
}

/// Waits for the next `kill_action` call recorded by this mock and returns
/// the `ActionId` it was given.
///
/// Panics if the channel yields no message.
pub async fn expect_kill_action(&self) -> ActionId {
    let mut receiver = self.rx_kill_action.lock().await;
    let maybe_action_id = receiver.recv().await;
    maybe_action_id.expect("Could not receive msg in mpsc")
}
}

#[async_trait]
Expand Down Expand Up @@ -151,6 +167,13 @@ impl RunningActionsManager for MockRunningActionsManager {
Ok(())
}

/// Mock implementation: records the requested `action_id` on the
/// kill-action channel so a test can read it back with `expect_kill_action`.
///
/// Always returns `Ok(())`; panics if the channel send fails.
async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> {
    let send_result = self.tx_kill_action.send(action_id);
    send_result.expect("Could not send request to mpsc");
    Ok(())
}

async fn kill_all(&self) {
self.tx_kill_all
.send(())
Expand Down

0 comments on commit c55983d

Please sign in to comment.