Skip to content

Commit

Permalink
Add --query-layers-key option for all algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
axdanbol committed Feb 29, 2024
1 parent f4a817e commit e8af375
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 11 deletions.
60 changes: 54 additions & 6 deletions containers/azimuth/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class AzimuthOrganMetadata(t.TypedDict):

class AzimuthOptions(t.TypedDict):
reference_data_dir: Path
query_layers_key: t.Optional[str]


class AzimuthAlgorithm(Algorithm[AzimuthOrganMetadata, AzimuthOptions]):
Expand All @@ -37,46 +38,92 @@ def do_run(
# obs columns of dtype 'object'. As a workaround we create a
# clean matrix without obs columns on which azimuth is run
# after which the annotations are copied back to the original matrix
temp_index = self.create_temp_obs_index(data)
clean_matrix_path = Path("clean_matrix.h5ad")
clean_matrix = self.create_clean_matrix(data)
clean_matrix = self.create_clean_matrix(data, temp_index)

self.set_data_layer(clean_matrix, options["query_layers_key"])
clean_matrix.write_h5ad(clean_matrix_path)

annotated_matrix_path = self.run_azimuth_scripts(
clean_matrix_path, reference_data
)
annotated_matrix = anndata.read_h5ad(annotated_matrix_path)
self.copy_annotations(data, annotated_matrix)
self.copy_annotations(data, annotated_matrix, temp_index)

return {
"data": data,
"organ_level": metadata["organ_level"],
"prediction_column": "predicted." + metadata["prediction_column"],
}

def create_clean_matrix(self, matrix: anndata.AnnData) -> anndata.AnnData:
def create_temp_obs_index(self, matrix: anndata.AnnData) -> pandas.Index:
"""Creates a new index by adding a prefix to each index name.
Used as a workaround for: https://github.com/satijalab/azimuth/issues/178
and https://github.com/satijalab/azimuth/issues/138
Args:
matrix (anndata.AnnData): Original data
Returns:
pandas.Index: A new index
"""
return matrix.obs.index.map(lambda name: f"QUERY:{name}")

def create_clean_matrix(
self,
matrix: anndata.AnnData,
temp_index: pandas.Index,
) -> anndata.AnnData:
"""Creates a copy of the data with all observation columns removed.
Args:
matrix (anndata.AnnData): Original data
temp_index (pandas.Index): Temporary index generated by `create_temp_obs_index`
Returns:
anndata.AnnData: Cleaned data
"""
clean_obs = pandas.DataFrame(index=matrix.obs.index)
clean_obs = pandas.DataFrame(index=temp_index)
clean_matrix = matrix.copy()
clean_matrix.obs = clean_obs

return clean_matrix

def set_data_layer(
self, matrix: anndata.AnnData, query_layers_key: t.Optional[str]
) -> None:
"""Set the data layer to use for annotating.
Args:
matrix (anndata.AnnData): Matrix to update
query_layers_key (t.Optional[str]): A layer name or 'raw'
"""
if query_layers_key == "raw":
matrix.X = matrix.raw.X
elif query_layers_key is not None:
matrix.X = matrix.layers[query_layers_key].copy()

def copy_annotations(
self, matrix: anndata.AnnData, annotated_matrix: anndata.AnnData
self,
matrix: anndata.AnnData,
annotated_matrix: anndata.AnnData,
temp_index: pandas.Index,
) -> None:
"""Copies annotations from one matrix to another.
Args:
matrix (anndata.AnnData): Matrix to copy to
annotated_matrix (anndata.AnnData): Matrix to copy from
temp_index (pandas.Index): Temporary index generated by `create_temp_obs_index`
"""
matrix.obs = matrix.obs.join(annotated_matrix.obs, rsuffix="_azimuth")
matrix.obs = matrix.obs.merge(
annotated_matrix.obs,
how="left",
left_on=temp_index,
right_index=True,
suffixes=(None, "_azimuth"),
)

def run_azimuth_scripts(self, matrix_path: Path, reference_data: Path) -> str:
"""Creates a subprocess running the Azimuth annotation R script.
Expand Down Expand Up @@ -156,6 +203,7 @@ def _get_arg_parser():
required=True,
help="Path to directory with reference data",
)
parser.add_argument("--query-layers-key", help="Data layer to use")

return parser

Expand Down
5 changes: 5 additions & 0 deletions containers/azimuth/options.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@ fields:
label: Directory with reference data directories
inputBinding:
prefix: --reference-data-dir
queryLayersKey:
type: string?
label: Data layer to use
inputBinding:
prefix: --query-layers-key
17 changes: 17 additions & 0 deletions containers/celltypist/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class CelltypistOrganMetadata(t.TypedDict):

class CelltypistOptions(t.TypedDict):
ensemble_lookup: Path
query_layers_key: t.Optional[str]


class CelltypistAlgorithm(Algorithm[CelltypistOrganMetadata, CelltypistOptions]):
Expand All @@ -31,6 +32,7 @@ def do_run(
) -> RunResult:
"""Annotate data using celltypist."""
data = scanpy.read_h5ad(matrix)
self.set_data_layer(data, options["query_layers_key"])
data = self.normalize(data)
data, var_names = self.normalize_var_names(data, options)
data = celltypist.annotate(
Expand All @@ -40,6 +42,20 @@ def do_run(

return {"data": data, "organ_level": metadata["model"].replace(".", "_")}

def set_data_layer(
self, matrix: scanpy.AnnData, query_layers_key: t.Optional[str]
) -> None:
"""Set the data layer to use for annotating.
Args:
matrix (anndata.AnnData): Matrix to update
query_layers_key (t.Optional[str]): A layer name or 'raw'
"""
if query_layers_key == "raw":
matrix.X = matrix.raw.X
elif query_layers_key is not None:
matrix.X = matrix.layers[query_layers_key].copy()

def normalize(self, data: scanpy.AnnData) -> scanpy.AnnData:
"""Normalizes data according to celltypist requirements.
Expand Down Expand Up @@ -115,6 +131,7 @@ def _get_arg_parser():
default="/ensemble-lookup.csv",
help="Ensemble id to gene name csv",
)
parser.add_argument("--query-layers-key", help="Data layer to use")

return parser

Expand Down
7 changes: 6 additions & 1 deletion containers/celltypist/options.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
type: record
name: options
label: Celltypist specific options
fields: {}
fields:
queryLayersKey:
type: string?
label: Data layer to use
inputBinding:
prefix: --query-layers-key
4 changes: 1 addition & 3 deletions containers/popv/context/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ def _get_arg_parser():
required=True,
help="Path to models directory",
)
parser.add_argument(
"--query-layers-key", required=True, help="Name of layer with raw counts"
)
parser.add_argument("--query-layers-key", help="Name of layer with raw counts")
parser.add_argument("--prediction-mode", default="fast", help="Prediction mode")
parser.add_argument(
"--cell-ontology-dir",
Expand Down
2 changes: 1 addition & 1 deletion containers/popv/options.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fields:
inputBinding:
prefix: --models-dir
queryLayersKey:
type: string
type: string?
inputBinding:
prefix: --query-layers-key
predictionMode:
Expand Down

0 comments on commit e8af375

Please sign in to comment.