Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 633726060
  • Loading branch information
Jake VanderPlas authored and Prompt Tuning Team committed May 15, 2024
1 parent 92ad4ab commit 126378b
Show file tree
Hide file tree
Showing 108 changed files with 113 additions and 113 deletions.
2 changes: 1 addition & 1 deletion prompt_tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/extended/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/extended/architectures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/extended/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/extended/runs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/extended/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/models/decoding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/models/sizes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/runs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/configs/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/c4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/glue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/glue_transfer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/metrics_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/postprocessors_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/preprocessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/preprocessors_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/qa.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/show_tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/summarization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/super_glue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/tasks_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/data/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/masks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/masks_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/multitask_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/multitask_prompts_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/perceptron/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/perceptron/configs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/perceptron/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/perceptron/data/tasks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/perceptron/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
10 changes: 5 additions & 5 deletions prompt_tuning/extended/perceptron/train/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,7 +35,7 @@


def fold(batch: Mapping[str, Array]) -> Mapping[str, Array]:
return jax.tree_map(_fold, batch)
return jax.tree.map(_fold, batch)


def _fold(arr: Array) -> Array:
Expand All @@ -51,7 +51,7 @@ def shape_fold(shape: Sequence[int]) -> Tuple[int, ...]:


def unfold(batch: Mapping[str, Array], batch_size: int) -> Mapping[str, Array]:
return jax.tree_map(functools.partial(_unfold, batch_size), batch)
return jax.tree.map(functools.partial(_unfold, batch_size), batch)


def _unfold(batch_size: int, arr: Array) -> Array:
Expand Down Expand Up @@ -98,8 +98,8 @@ def _compute_logits( # pytype: disable=signature-mismatch
batch_size = batch["encoder_input_tokens"].shape[0]
# [B, C, T] -> [B * C, T]
folded_batch = fold(batch)
logging.info("Batch: %s", jax.tree_map(jnp.shape, batch))
logging.info("Folded Batch: %s", jax.tree_map(jnp.shape, folded_batch))
logging.info("Batch: %s", jax.tree.map(jnp.shape, batch))
logging.info("Folded Batch: %s", jax.tree.map(jnp.shape, folded_batch))
# [B * C, T, V]
logits = super()._compute_logits(params, folded_batch, dropout_rng, mutable)
# [B * C, T]
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/ia3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
4 changes: 2 additions & 2 deletions prompt_tuning/extended/train/ia3_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,7 +37,7 @@ def test_reshpaed_attention_equal_default(self):

params = ia3.IA3(ia3_init=nn.initializers.uniform()).init(
jax.random.PRNGKey(0), default_input)
attention_params = jax.tree_map(
attention_params = jax.tree.map(
lambda x: jnp.reshape(x, (heads, kv)), params)

default_result = jax.jit(ia3.IA3().apply)(params, default_input)
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/multitask_partitioning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/multitask_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/multitask_prompts_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/per_layer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/per_layer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/extended/train/wayward.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/masks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/masks_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion prompt_tuning/prompts_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Google.
# Copyright 2024 Google.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down

0 comments on commit 126378b

Please sign in to comment.