Skip to content

Commit

Permalink
ConnectionManager can now be destroyed safely
Browse files Browse the repository at this point in the history
Reduces chance of any issues related to ConnectionManager
not being destroyed safely due to un-terminated spawns.
  • Loading branch information
allada committed Apr 19, 2024
1 parent 5583a5d commit 502e952
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
19 changes: 19 additions & 0 deletions nativelink-config/src/stores.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::serde_utils::{
Expand Down Expand Up @@ -534,16 +536,33 @@ pub struct ClientTlsConfig {
pub key_file: Option<String>,
}

/// Header value type. This type represents the value of a header.
///
/// Only one variant exists today; the type is an enum so further ways of
/// producing a header value can be added later (see the TODO below).
// TODO(allada) We should add a command and refresh rate type that
// can be used to generate short-lived auto-refreshing headers for
// things like OAuth2's Bearer token.
// The variant below is deliberately lower-case, which is why the
// `non_camel_case_types` lint is silenced here.
// NOTE(review): presumably this keeps the serialized key as `value` in
// configuration files — confirm against the config format.
#[allow(non_camel_case_types)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum HeaderValueType {
/// A string holding the header's value.
value(String),
}

/// Configuration describing a single upstream gRPC endpoint.
///
/// Unknown fields are rejected during deserialization
/// (`deny_unknown_fields`), so a misspelled key is an error rather than
/// being silently ignored.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(deny_unknown_fields)]
pub struct GrpcEndpoint {
/// The endpoint address (i.e. grpc(s)://example.com:443).
// NOTE(review): the custom deserializer presumably performs shell-style
// expansion (e.g. environment variables) on the string — confirm in
// `serde_utils::convert_string_with_shellexpand`.
#[serde(deserialize_with = "convert_string_with_shellexpand")]
pub address: String,

/// The TLS configuration to use to connect to the endpoint (if grpcs).
pub tls_config: Option<ClientTlsConfig>,

/// The maximum concurrency to allow on this endpoint.
pub concurrency_limit: Option<usize>,

/// Additional headers to send with the request, keyed by header name.
///
/// Defaults to an empty map when the field is omitted from the config.
#[serde(default)]
pub additional_headers: HashMap<String, HeaderValueType>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand Down
42 changes: 34 additions & 8 deletions nativelink-util/src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::time::Duration;
use futures::stream::{unfold, FuturesUnordered, StreamExt};
use futures::Future;
use nativelink_config::stores::Retry;
use nativelink_error::{make_err, Code, Error};
use nativelink_error::{make_err, Code, Error, ResultExt};
use tokio::sync::{mpsc, oneshot};
use tonic::transport::{channel, Channel, Endpoint};
use tracing::{debug, error, info, warn};
Expand All @@ -32,7 +32,10 @@ use crate::retry::{self, Retrier, RetryResult};
/// upstream gRPC endpoint using Tonic.
pub struct ConnectionManager {
// The channel to request connections from the worker.
worker_tx: mpsc::Sender<oneshot::Sender<Connection>>,
worker_tx: Option<mpsc::Sender<oneshot::Sender<Connection>>>,

// The handle to the spawned worker task that services connection requests.
service_spawn: Option<tokio::task::JoinHandle<()>>,
}

/// The index into ConnectionManagerWorker::endpoints.
Expand Down Expand Up @@ -148,12 +151,14 @@ impl ConnectionManager {
retry,
),
};
tokio::spawn(async move {
worker
.service_requests(connections_per_endpoint, worker_rx, connection_rx)
.await;
});
Self { worker_tx }
Self {
worker_tx: Some(worker_tx),
service_spawn: Some(tokio::spawn(async move {
worker
.service_requests(connections_per_endpoint, worker_rx, connection_rx)
.await;
})),
}
}

/// Get a Connection that can be used as a tonic::Channel, except it
Expand All @@ -162,6 +167,8 @@ impl ConnectionManager {
pub async fn connection(&self) -> Result<Connection, Error> {
let (tx, rx) = oneshot::channel();
self.worker_tx
.as_ref()
.err_tip(|| "ConnectionManager is already dropped")?
.send(tx)
.await
.map_err(|err| make_err!(Code::Unavailable, "Requesting a new connection: {err:?}"))?;
Expand All @@ -170,6 +177,25 @@ impl ConnectionManager {
}
}

impl Drop for ConnectionManager {
fn drop(&mut self) {
// Drop our worker_tx to signal to the worker that it should shut down.
drop(self.worker_tx.take());

let service_spawn = self
.service_spawn
.take()
.expect("Expected ConnectionManager::service_spawn to exist");
tokio::pin!(service_spawn);
service_spawn.abort();
// Wait for the worker to shut down.
let result = tokio::runtime::Handle::current().block_on(service_spawn);
if let Err(err) = result {
error!("Error while dropping ConnectionManager: {err:?}");
}
}
}

impl ConnectionManagerWorker {
async fn service_requests(
mut self,
Expand Down

0 comments on commit 502e952

Please sign in to comment.