Skip to content

Commit

Permalink
Adding traits ProtocolMessageFlight and OpaqueProtocolMessageFlight
Browse files Browse the repository at this point in the history
  • Loading branch information
aeyno committed Jun 3, 2024
1 parent 4751e6f commit 2883748
Show file tree
Hide file tree
Showing 13 changed files with 292 additions and 166 deletions.
96 changes: 21 additions & 75 deletions puffin/src/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,97 +1,37 @@
use std::{fmt::Debug, marker::PhantomData};

use log::debug;
use std::fmt::Debug;

use crate::{
algebra::{signature::Signature, Matcher},
claims::{Claim, SecurityViolationPolicy},
codec::{Codec, Reader},
codec::Codec,
error::Error,
trace::Trace,
variable_data::VariableData,
};

/// Store a message flight, a vec of all the messages sent by the PUT between two steps
#[derive(Debug, Clone)]
pub struct MessageFlight<M: ProtocolMessage<O>, O: OpaqueProtocolMessage> {
pub messages: Vec<M>,
phantom: PhantomData<O>,
}

impl<M: ProtocolMessage<O>, O: OpaqueProtocolMessage> MessageFlight<M, O> {
pub fn new() -> Self {
MessageFlight {
messages: vec![],
phantom: PhantomData,
}
}

pub fn debug(&self, info: &str) {
debug!("{}: {:?}", info, self);
}
pub trait ProtocolMessageFlight<M: ProtocolMessage<O>, O: OpaqueProtocolMessage>:
Clone + Debug + From<M>
{
fn new() -> Self;
fn push(&mut self, msg: M);
fn debug(&self, info: &str);
}

/// Store a flight of opaque messages, a vec of all the messages sent by the PUT between two steps
#[derive(Debug, Clone)]
pub struct OpaqueMessageFlight<O: OpaqueProtocolMessage> {
pub messages: Vec<O>,
}

impl<O: OpaqueProtocolMessage> OpaqueMessageFlight<O> {
pub fn new() -> Self {
OpaqueMessageFlight { messages: vec![] }
}

pub fn debug(&self, info: &str) {
debug!("{}: {:?}", info, self);
}

pub fn get_encoding(self) -> Vec<u8> {
pub trait OpaqueProtocolMessageFlight<O: OpaqueProtocolMessage>:
Clone + Debug + Codec + From<O>
{
fn new() -> Self;
fn debug(&self, info: &str);
fn push(&mut self, msg: O);
fn get_encoding(self) -> Vec<u8> {
let mut buf = Vec::new();
self.encode(&mut buf);
buf
}
}

impl<M: ProtocolMessage<O>, O: OpaqueProtocolMessage> From<M> for MessageFlight<M, O> {
fn from(value: M) -> Self {
MessageFlight {
messages: vec![value],
phantom: PhantomData,
}
}
}

impl<M: ProtocolMessage<O>, O: OpaqueProtocolMessage> From<MessageFlight<M, O>>
for OpaqueMessageFlight<O>
{
fn from(value: MessageFlight<M, O>) -> Self {
OpaqueMessageFlight {
messages: value.messages.iter().map(|m| m.create_opaque()).collect(),
}
}
}

impl<O: OpaqueProtocolMessage> Codec for OpaqueMessageFlight<O> {
fn encode(&self, bytes: &mut Vec<u8>) {
for msg in &self.messages {
msg.encode(bytes);
}
}

fn read(_reader: &mut Reader) -> Option<Self> {
None
}
}

impl<O: OpaqueProtocolMessage> From<O> for OpaqueMessageFlight<O> {
fn from(value: O) -> Self {
OpaqueMessageFlight {
messages: vec![value],
}
}
}

/// A structured message. This type defines how all possible messages of a protocol.
/// Usually this is implemented using an `enum`.
pub trait ProtocolMessage<O: OpaqueProtocolMessage>: Clone + Debug {
Expand Down Expand Up @@ -133,6 +73,12 @@ pub trait ProtocolBehavior: 'static {

type ProtocolMessage: ProtocolMessage<Self::OpaqueProtocolMessage>;
type OpaqueProtocolMessage: OpaqueProtocolMessage;
type ProtocolMessageFlight: ProtocolMessageFlight<
Self::ProtocolMessage,
Self::OpaqueProtocolMessage,
>;
type OpaqueProtocolMessageFlight: OpaqueProtocolMessageFlight<Self::OpaqueProtocolMessage>
+ From<Self::ProtocolMessageFlight>;

type Matcher: Matcher
+ for<'a> TryFrom<&'a MessageResult<Self::ProtocolMessage, Self::OpaqueProtocolMessage>>;
Expand Down
2 changes: 1 addition & 1 deletion puffin/src/put.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct PutDescriptor {
/// Generic trait used to define the interface with a concrete library
/// implementing the protocol.
pub trait Put<PB: ProtocolBehavior>:
Stream<PB::ProtocolMessage, PB::OpaqueProtocolMessage> + 'static
Stream<PB::ProtocolMessage, PB::OpaqueProtocolMessage, PB::OpaqueProtocolMessageFlight> + 'static
{
/// Process incoming buffer, internal progress, can fill in the output buffer
fn progress(&mut self, agent_name: &AgentName) -> Result<(), Error>;
Expand Down
17 changes: 11 additions & 6 deletions puffin/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ use std::{
use log::error;

use crate::{
codec::Codec,
error::Error,
protocol::{
MessageResult, OpaqueMessageFlight, OpaqueProtocolMessage, ProtocolMessage,
MessageResult, OpaqueProtocolMessage, OpaqueProtocolMessageFlight, ProtocolMessage,
ProtocolMessageDeframer,
},
};

pub trait Stream<M: ProtocolMessage<O>, O: OpaqueProtocolMessage> {
fn add_to_inbound(&mut self, message_flight: &OpaqueMessageFlight<O>);
pub trait Stream<M: ProtocolMessage<O>, O: OpaqueProtocolMessage, F: OpaqueProtocolMessageFlight<O>>
{
fn add_to_inbound(&mut self, message_flight: &F);

/// Takes a single TLS message from the outbound channel
fn take_message_from_outbound(&mut self) -> Result<Option<MessageResult<M, O>>, Error>;
Expand Down Expand Up @@ -70,14 +70,19 @@ impl<D: ProtocolMessageDeframer> MemoryStream<D> {
}
}

impl<M, D: ProtocolMessageDeframer, E> Stream<M, D::OpaqueProtocolMessage> for MemoryStream<D>
impl<
M,
D: ProtocolMessageDeframer,
E,
F: OpaqueProtocolMessageFlight<D::OpaqueProtocolMessage>,
> Stream<M, D::OpaqueProtocolMessage, F> for MemoryStream<D>
where
M: ProtocolMessage<D::OpaqueProtocolMessage>,
D::OpaqueProtocolMessage: TryInto<M, Error = E>,
E: Into<Error>,
M: TryInto<M>,
{
fn add_to_inbound(&mut self, message_flight: &OpaqueMessageFlight<D::OpaqueProtocolMessage>) {
fn add_to_inbound(&mut self, message_flight: &F) {
message_flight.encode(self.inbound.get_mut());
}

Expand Down
17 changes: 8 additions & 9 deletions puffin/src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use crate::{
claims::{Claim, GlobalClaimList, SecurityViolationPolicy},
error::Error,
protocol::{
MessageFlight, MessageResult, OpaqueMessageFlight, OpaqueProtocolMessage, ProtocolBehavior,
ProtocolMessage,
MessageResult, OpaqueProtocolMessage, OpaqueProtocolMessageFlight, ProtocolBehavior,
ProtocolMessage, ProtocolMessageFlight,
},
put::{PutDescriptor, PutOptions},
put_registry::PutRegistry,
Expand Down Expand Up @@ -242,7 +242,7 @@ impl<PB: ProtocolBehavior> TraceContext<PB> {
pub fn add_to_inbound(
&mut self,
agent_name: AgentName,
message_flight: &OpaqueMessageFlight<PB::OpaqueProtocolMessage>,
message_flight: &PB::OpaqueProtocolMessageFlight,
) -> Result<(), Error> {
self.find_agent_mut(agent_name)
.map(|agent| agent.put_mut().add_to_inbound(message_flight))
Expand Down Expand Up @@ -540,15 +540,15 @@ impl<M: Matcher> OutputAction<M> {
{
ctx.next_state(step.agent)?;

let mut flight = MessageFlight::new();
let mut flight = PB::ProtocolMessageFlight::new();

while let Some(message_result) = ctx.take_message_from_outbound(step.agent)? {
let matcher = message_result.create_matcher::<PB>();

let MessageResult(message, opaque_message) = message_result;

if let Some(m) = &message {
flight.messages.push(m.clone());
flight.push(m.clone());
}

let knowledge = message
Expand Down Expand Up @@ -634,23 +634,22 @@ impl<M: Matcher> InputAction<M> {

if let Some(flight) = evaluated
.as_ref()
.downcast_ref::<MessageFlight<PB::ProtocolMessage, PB::OpaqueProtocolMessage>>()
.downcast_ref::<PB::ProtocolMessageFlight>()
{
flight.debug("Input message flight");

ctx.add_to_inbound(step.agent, &flight.clone().into())?;
} else if let Some(flight) = evaluated
.as_ref()
.downcast_ref::<OpaqueMessageFlight<PB::OpaqueProtocolMessage>>()
.downcast_ref::<PB::OpaqueProtocolMessageFlight>()
{
flight.debug("Input opaque message flight");

ctx.add_to_inbound(step.agent, &flight)?;
} else if let Some(msg) = evaluated.as_ref().downcast_ref::<PB::ProtocolMessage>() {
msg.debug("Input message");

let message_flight: MessageFlight<PB::ProtocolMessage, PB::OpaqueProtocolMessage> =
msg.clone().into();
let message_flight: PB::ProtocolMessageFlight = msg.clone().into();
ctx.add_to_inbound(step.agent, &message_flight.into())?;
} else if let Some(opaque_message) = evaluated
.as_ref()
Expand Down
8 changes: 4 additions & 4 deletions sshpuffin/src/libssh/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use puffin::{
agent::{AgentDescriptor, AgentName, AgentType},
codec::Codec,
error::Error,
protocol::{MessageResult, OpaqueMessageFlight},
protocol::MessageResult,
put::{Put, PutName},
put_registry::{Factory, PutKind},
stream::Stream,
Expand All @@ -36,7 +36,7 @@ use crate::{
SessionOption, SessionState, SshAuthResult, SshBind, SshBindOption, SshKey, SshRequest,
SshResult, SshSession,
},
protocol::SshProtocolBehavior,
protocol::{RawSshMessageFlight, SshProtocolBehavior},
put_registry::LIBSSH_PUT,
ssh::{
deframe::SshMessageDeframer,
Expand Down Expand Up @@ -211,8 +211,8 @@ pub struct LibSSL {

impl LibSSL {}

impl Stream<SshMessage, RawSshMessage> for LibSSL {
fn add_to_inbound(&mut self, result: &OpaqueMessageFlight<RawSshMessage>) {
impl Stream<SshMessage, RawSshMessage, RawSshMessageFlight> for LibSSL {
fn add_to_inbound(&mut self, result: &RawSshMessageFlight) {
let mut buffer = Vec::new();
Codec::encode(result, &mut buffer);

Expand Down
92 changes: 91 additions & 1 deletion sshpuffin/src/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,106 @@
use log::debug;
use puffin::{
algebra::{signature::Signature, AnyMatcher},
protocol::ProtocolBehavior,
codec::{Codec, Reader},
protocol::{
OpaqueProtocolMessageFlight, ProtocolBehavior, ProtocolMessage, ProtocolMessageDeframer,
ProtocolMessageFlight,
},
trace::Trace,
};

use crate::{
claim::SshClaim,
ssh::{
deframe::SshMessageDeframer,
message::{RawSshMessage, SshMessage},
SSH_SIGNATURE,
},
violation::SshSecurityViolationPolicy,
};

#[derive(Debug, Clone)]
pub struct SshMessageFlight {
pub messages: Vec<SshMessage>,
}

impl ProtocolMessageFlight<SshMessage, RawSshMessage> for SshMessageFlight {
fn new() -> Self {
Self { messages: vec![] }
}

fn push(&mut self, msg: SshMessage) {
self.messages.push(msg);
}

fn debug(&self, info: &str) {
debug!("{}: {:?}", info, self);
}
}

impl From<SshMessage> for SshMessageFlight {
fn from(value: SshMessage) -> Self {
Self {
messages: vec![value],
}
}
}

#[derive(Debug, Clone)]
pub struct RawSshMessageFlight {
pub messages: Vec<RawSshMessage>,
}

impl OpaqueProtocolMessageFlight<RawSshMessage> for RawSshMessageFlight {
fn new() -> Self {
Self { messages: vec![] }
}

fn push(&mut self, msg: RawSshMessage) {
self.messages.push(msg);
}

fn debug(&self, info: &str) {
debug!("{}: {:?}", info, self);
}
}

impl Codec for RawSshMessageFlight {
fn encode(&self, bytes: &mut Vec<u8>) {
for msg in &self.messages {
msg.encode(bytes);
}
}

fn read(reader: &mut Reader) -> Option<Self> {
let mut deframer = SshMessageDeframer::new();
let mut flight = Self::new();

let _ = deframer.read(&mut reader.rest());
while let Some(msg) = deframer.pop_frame() {
flight.push(msg);
}

Some(flight)
}
}

impl From<SshMessageFlight> for RawSshMessageFlight {
fn from(value: SshMessageFlight) -> Self {
Self {
messages: value.messages.iter().map(|m| m.create_opaque()).collect(),
}
}
}

impl From<RawSshMessage> for RawSshMessageFlight {
fn from(value: RawSshMessage) -> Self {
Self {
messages: vec![value],
}
}
}

#[derive(Clone, Debug, PartialEq)]
pub struct SshProtocolBehavior {}

Expand All @@ -22,6 +110,8 @@ impl ProtocolBehavior for SshProtocolBehavior {
type ProtocolMessage = SshMessage;
type OpaqueProtocolMessage = RawSshMessage;
type Matcher = AnyMatcher;
type ProtocolMessageFlight = SshMessageFlight;
type OpaqueProtocolMessageFlight = RawSshMessageFlight;

fn signature() -> &'static Signature {
&SSH_SIGNATURE
Expand Down
Loading

0 comments on commit 2883748

Please sign in to comment.