Skip to content

Commit

Permalink
Hooks up preprocessing checking in GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomaz-Vieira committed May 5, 2024
1 parent 0bfa064 commit 44406a4
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 100 deletions.
105 changes: 75 additions & 30 deletions bioimg_gui/src/widgets/inout_tensor_widget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::sync::Arc;

use bioimg_runtime::model_interface::{InputSlot, OutputSlot};
use bioimg_runtime::npy_array::ArcNpyArray;
use bioimg_runtime::NpyArray;

use crate::result::{GuiError, Result};
use crate::result::{GuiError, Result, VecResultExt};
use bioimg_spec::rdf::model as modelrdf;
use bioimg_spec::rdf::model::input_tensor as rdfinput;

use super::error_display::show_error;
use super::error_display::{show_error, show_if_error};
use super::file_widget::{FileWidget, FileWidgetState};
use super::preprocessing_widget::PreprocessingWidget;
use super::staging_string::StagingString;
Expand All @@ -18,14 +20,29 @@ use super::{StatefulWidget, ValueWidget};
use crate::widgets::staging_vec::ItemWidgetConf;


#[derive(Default)]
/// GUI widget for editing a single model *input* tensor: its id, optionality,
/// text description, axes, test-tensor file and preprocessing steps, plus the
/// cached result of parsing all of those into an `InputSlot`.
pub struct InputTensorWidget {
// Staged text input for the tensor id; parses into `modelrdf::TensorId`.
pub id_widget: StagingString<modelrdf::TensorId>,
// Whether this input tensor is optional for the model.
pub is_optional: bool,
// Staged free-text description of the tensor.
pub description_widget: StagingString<modelrdf::TensorTextDescription>,
// One axis sub-widget per input axis.
pub axes_widget: StagingVec<InputAxisWidget>,
// File picker for the `.npy` test tensor; holds the loaded array on success.
pub test_tensor_widget: FileWidget<Result<ArcNpyArray>>,
// Zero or more preprocessing-step sub-widgets.
pub preprocessing_widget: StagingVec<PreprocessingWidget>,

// Result of the last parse attempt, refreshed by `draw_and_parse` and
// returned by reference from `state()`. Starts out as an error (see the
// manual `Default` impl).
pub parsed: Result<InputSlot<Arc<NpyArray>>>,
}

impl Default for InputTensorWidget{
fn default() -> Self {
Self{
id_widget: Default::default(),
is_optional: Default::default(),
description_widget: Default::default(),
axes_widget: Default::default(),
test_tensor_widget: Default::default(),
preprocessing_widget: Default::default(),
parsed: Err(GuiError::new("empty".to_owned())),
}
}
}

impl ValueWidget for InputTensorWidget{
Expand All @@ -46,7 +63,7 @@ impl ItemWidgetConf for InputTensorWidget{
}

impl StatefulWidget for InputTensorWidget {
type Value<'p> = Result<InputSlot<ArcNpyArray>>;
type Value<'p> = &'p Result<InputSlot<ArcNpyArray>>;

fn draw_and_parse(&mut self, ui: &mut egui::Ui, id: egui::Id) {
if let FileWidgetState::Finished { path, value: Ok(gui_npy_arr) } = self.test_tensor_widget.state() {
Expand Down Expand Up @@ -98,32 +115,53 @@ impl StatefulWidget for InputTensorWidget {
ui.strong("Preprocessing: ");
self.preprocessing_widget.draw_and_parse(ui, id.with("preproc".as_ptr()));
});

self.parsed = || -> Result<InputSlot<Arc<NpyArray>>> {
let FileWidgetState::Finished { value: Ok(gui_npy_array), .. } = self.test_tensor_widget.state() else {
return Err(GuiError::new("Test tensor is missing".into()));
};
let axes = self.axes_widget.state().into_iter().collect::<Result<Vec<_>>>()?;
let input_axis_group = modelrdf::InputAxisGroup::try_from(axes)?; //FIXME: parse in draw_and_parse?
let meta_msg = rdfinput::InputTensorMetadataMsg{
id: self.id_widget.state()?,
optional: self.is_optional,
preprocessing: self.preprocessing_widget.state().collect_result()?,
description: self.description_widget.state()?,
axes: input_axis_group,
};
Ok(
InputSlot{ tensor_meta: meta_msg.try_into()?, test_tensor: Arc::clone(gui_npy_array) }
)
}();

show_if_error(ui, &self.parsed);
});
}

fn state<'p>(&'p self) -> Self::Value<'p> {
let FileWidgetState::Finished { value: Ok(gui_npy_array), .. } = self.test_tensor_widget.state() else {
return Err(GuiError::new("Test tensor is missing".into()));
};
let axes = self.axes_widget.state().into_iter().collect::<Result<Vec<_>>>()?;
let input_axis_group = modelrdf::InputAxisGroup::try_from(axes)?; //FIXME: parse in draw_and_parse?
Ok( InputSlot {
id: self.id_widget.state()?,
optional: self.is_optional,
preprocessing: vec![], //FIXME
description: self.description_widget.state()?,
axes: input_axis_group,
test_tensor: Arc::clone(gui_npy_array),
})
&self.parsed
}
}

#[derive(Default)]
/// GUI widget for editing a single model *output* tensor: its id, text
/// description, axes and test-tensor file, plus the cached result of parsing
/// all of those into an `OutputSlot`.
pub struct OutputTensorWidget {
// Staged text input for the tensor id; parses into `modelrdf::TensorId`.
pub id_widget: StagingString<modelrdf::TensorId>,
// Staged free-text description of the tensor.
pub description_widget: StagingString<modelrdf::TensorTextDescription>,
// One axis sub-widget per output axis.
pub axes_widget: StagingVec<OutputAxisWidget>,
// File picker for the `.npy` test tensor; holds the loaded array on success.
pub test_tensor_widget: FileWidget<Result<ArcNpyArray>>,

// Result of the last parse attempt, refreshed by `draw_and_parse` and
// returned by reference from `state()`. Starts out as an error (see the
// manual `Default` impl).
pub parsed: Result<OutputSlot<Arc<NpyArray>>>,
}

impl Default for OutputTensorWidget{
fn default() -> Self {
Self{
id_widget: Default::default(),
description_widget: Default::default(),
axes_widget: Default::default(),
test_tensor_widget: Default::default(),
parsed: Err(GuiError::new("empty".to_owned()))
}
}
}

impl ValueWidget for OutputTensorWidget{
Expand All @@ -144,7 +182,7 @@ impl ItemWidgetConf for OutputTensorWidget{
}

impl StatefulWidget for OutputTensorWidget {
type Value<'p> = Result<OutputSlot<ArcNpyArray>>;
type Value<'p> = &'p Result<OutputSlot<ArcNpyArray>>;

fn draw_and_parse(&mut self, ui: &mut egui::Ui, id: egui::Id) {
if let FileWidgetState::Finished { path, value: Ok(gui_npy_arr) } = self.test_tensor_widget.state() {
Expand Down Expand Up @@ -188,20 +226,27 @@ impl StatefulWidget for OutputTensorWidget {
ui.strong("Axes: ");
self.axes_widget.draw_and_parse(ui, id.with("Axes"));
});
self.parsed = || -> Result<OutputSlot<Arc<NpyArray>>> {
let FileWidgetState::Finished { value: Ok(gui_npy_array), .. } = self.test_tensor_widget.state() else {
return Err(GuiError::new("Test tensor is missing".into()));
};
let axes = self.axes_widget.state().into_iter().collect::<Result<Vec<_>>>()?;
let axis_group = modelrdf::OutputAxisGroup::try_from(axes)?; //FIXME: parse in draw_and_parse?
let meta_msg = modelrdf::output_tensor::OutputTensorMetadataMsg{
id: self.id_widget.state()?,
description: self.description_widget.state()?,
axes: axis_group,
};
Ok(
OutputSlot{ tensor_meta: meta_msg.try_into()?, test_tensor: Arc::clone(gui_npy_array) }
)
}();

show_if_error(ui, &self.parsed);
});
}

fn state<'p>(&'p self) -> Self::Value<'p> {
let FileWidgetState::Finished { value: Ok(gui_npy_array), .. } = self.test_tensor_widget.state() else {
return Err(GuiError::new("Test tensor is missing".into()));
};
let axes = self.axes_widget.state().into_iter().collect::<Result<Vec<_>>>()?;
let input_axis_group = modelrdf::OutputAxisGroup::try_from(axes)?; //FIXME: parse in draw_and_parse?
Ok( OutputSlot {
id: self.id_widget.state()?,
description: self.description_widget.state()?,
axes: input_axis_group,
test_tensor: Arc::clone(gui_npy_array),
})
&self.parsed
}
}
4 changes: 2 additions & 2 deletions bioimg_gui/src/widgets/model_interface_widget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ impl StatefulWidget for ModelInterfaceWidget {
self.outputs_widget.draw_and_parse(ui, id.with("out"));
});

let inputs = match self.inputs_widget.state().into_iter().collect::<Result<Vec<_>>>() {
let inputs = match self.inputs_widget.state().into_iter().map(|i| i.clone()).collect::<Result<Vec<_>>>() {
Ok(inps) => inps,
Err(_) => {
show_error(ui, format!("Check inputs for errors"));
return;
}
};
let outputs = match self.outputs_widget.state().into_iter().collect::<Result<Vec<_>>>() {
let outputs = match self.outputs_widget.state().into_iter().map(|i| i.clone()).collect::<Result<Vec<_>>>() {
Ok(outs) => outs,
Err(_) => {
show_error(ui, format!("Check outputs for errors"));
Expand Down
50 changes: 19 additions & 31 deletions bioimg_runtime/src/model_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@ use super::axis_size_resolver::AxisSizeResolutionError;
#[allow(dead_code)]
#[derive(Clone)]
pub struct InputSlot <DATA: Borrow<NpyArray>> {
pub id: TensorId,
pub optional: bool,
pub preprocessing: Vec<modelrdf::PreprocessingDescr>,
pub description: modelrdf::TensorTextDescription,
pub axes: modelrdf::InputAxisGroup,
pub tensor_meta: modelrdf::input_tensor::InputTensorMetadata,
pub test_tensor: DATA,
}

Expand All @@ -30,20 +26,16 @@ impl<DATA: Borrow<NpyArray>> InputSlot <DATA> {
&self,
zip_file: &mut ModelZipWriter<impl Write + Seek>,
) -> Result<modelrdf::InputTensorDescr, ModelPackingError> {
let test_tensor_zip_path = rdf::FsPath::unique_suffixed(&format!("_{}_test_tensor.npy", self.id));
let test_tensor_zip_path = rdf::FsPath::unique_suffixed(&format!("_{}_test_tensor.npy", self.tensor_meta.id));
zip_file.write_file(&test_tensor_zip_path, |writer| self.test_tensor.borrow().write_npy(writer))?;
Ok(modelrdf::input_tensor::InputTensorDescrMessage{
id: self.id.clone(),
optional: self.optional,
preprocessing: self.preprocessing.clone(),
description: self.description.clone(),
axes: self.axes.clone(),
Ok(modelrdf::input_tensor::InputTensorDescr{
meta: self.tensor_meta.clone(),
test_tensor: rdf::FileDescription{
source: test_tensor_zip_path.into(),
sha256: None,
},
sample_tensor: None, //FIXME
}.try_into()?)
})
}
}

Expand All @@ -56,9 +48,9 @@ pub trait VecInputSlotExt{
impl<DATA: Borrow<NpyArray>> VecInputSlotExt for [InputSlot<DATA>]{
fn qual_id_axes(&self) -> impl Iterator<Item=(QualifiedAxisId, &InputAxis)>{
self.iter()
.map(|rt_tensor_descr|{
rt_tensor_descr.axes.iter().map(|axis|{
let qual_id = QualifiedAxisId{tensor_id: rt_tensor_descr.id.clone(), axis_id: axis.id()};
.map(|slot|{
slot.tensor_meta.axes().iter().map(|axis|{
let qual_id = QualifiedAxisId{tensor_id: slot.tensor_meta.id.clone(), axis_id: axis.id()};
(qual_id, axis)
})
})
Expand All @@ -69,9 +61,7 @@ impl<DATA: Borrow<NpyArray>> VecInputSlotExt for [InputSlot<DATA>]{
#[allow(dead_code)]
#[derive(Clone)]
pub struct OutputSlot<DATA: Borrow<NpyArray>> {
pub id: TensorId,
pub description: modelrdf::TensorTextDescription,
pub axes: modelrdf::OutputAxisGroup,
pub tensor_meta: modelrdf::output_tensor::OutputTensorMetadata,
pub test_tensor: DATA,
}

Expand All @@ -80,12 +70,10 @@ impl<DATA: Borrow<NpyArray>> OutputSlot<DATA> {
&self,
zip_file: &mut ModelZipWriter<impl Write + Seek>,
) -> Result<modelrdf::OutputTensorDescr, ModelPackingError> {
let test_tensor_zip_path = rdf::FsPath::unique_suffixed(&format!("_{}_test_tensor.npy", self.id));
let test_tensor_zip_path = rdf::FsPath::unique_suffixed(&format!("_{}_test_tensor.npy", self.tensor_meta.id));
zip_file.write_file(&test_tensor_zip_path, |writer| self.test_tensor.borrow().write_npy(writer))?;
Ok(modelrdf::OutputTensorDescr{
id: self.id.clone(),
description: self.description.clone(),
axes: self.axes.clone(),
metadata: self.tensor_meta.clone(),
test_tensor: rdf::FileDescription{
source: test_tensor_zip_path.into(),
sha256: None,
Expand All @@ -104,9 +92,9 @@ pub trait VecOutputSlotExt{
impl<DATA: Borrow<NpyArray>> VecOutputSlotExt for [OutputSlot<DATA>]{
fn qual_id_axes(&self) -> impl Iterator<Item=(QualifiedAxisId, &OutputAxis)>{
self.iter()
.map(|rt_tensor_descr|{
rt_tensor_descr.axes.iter().map(|axis|{
let qual_id = QualifiedAxisId{tensor_id: rt_tensor_descr.id.clone(), axis_id: axis.id()};
.map(|slot|{
slot.tensor_meta.axes().iter().map(|axis|{
let qual_id = QualifiedAxisId{tensor_id: slot.tensor_meta.id.clone(), axis_id: axis.id()};
(qual_id, axis)
})
})
Expand Down Expand Up @@ -173,8 +161,8 @@ impl<DATA: Borrow<NpyArray>> ModelInterface<DATA> {
{
let capacity: usize = usize::from(inputs.len()) + usize::from(outputs.len());
let mut seen_tensor_ids = HashSet::<&TensorId>::with_capacity(capacity);
inputs.iter().map(|tensor_descr| &tensor_descr.id)
.chain(outputs.iter().map(|tensor_descr| &tensor_descr.id))
inputs.iter().map(|slot| &slot.tensor_meta.id)
.chain(outputs.iter().map(|slot| &slot.tensor_meta.id))
.map(|tensor_id|{
if !seen_tensor_ids.insert(tensor_id){
Err(TensorValidationError::DuplicateTensorId(tensor_id.clone()))
Expand All @@ -195,17 +183,17 @@ impl<DATA: Borrow<NpyArray>> ModelInterface<DATA> {
for slot in $slots.iter(){
let test_tensor_shape = slot.test_tensor.borrow().shape();
let mut test_tensor_dims = test_tensor_shape.iter().enumerate();
for axis in slot.axes.iter(){
for axis in slot.tensor_meta.axes().iter(){
let Some((test_tensor_dim_index, test_tensor_dim_size)) = test_tensor_dims.next() else{
return Err(TensorValidationError::MismatchedNumDimensions {
test_tensor_shape: test_tensor_shape.into(),
num_described_axes: slot.axes.len(),
num_described_axes: slot.tensor_meta.axes().len(),
});
};
if axis.size().is_none(){ // batch i guess?
continue;
};
let qual_id = QualifiedAxisId{tensor_id: slot.id.clone(), axis_id: axis.id()};
let qual_id = QualifiedAxisId{tensor_id: slot.tensor_meta.id.clone(), axis_id: axis.id()};
let resolved = size_map.get(&qual_id).unwrap();
if !resolved.is_compatible_with_extent(*test_tensor_dim_size){
return Err(TensorValidationError::IncompatibleAxis {
Expand Down
Loading

0 comments on commit 44406a4

Please sign in to comment.