diff --git a/tfdeterminism/enable_determinism.py b/tfdeterminism/enable_determinism.py
index 354bbd4..96ecec5 100644
--- a/tfdeterminism/enable_determinism.py
+++ b/tfdeterminism/enable_determinism.py
@@ -21,7 +21,7 @@
 import tensorflow as tf
 
-from .patch import _patch_bias_add
+from .patch import _patch_bias_add, _patch_fused_softmax_cross_entropy
 from .utils import _Version as Version
 
 
 def _enable_determinism(seed=None):
@@ -31,7 +31,7 @@ def _enable_determinism(seed=None):
   Call this method either before or after explicitly importing TensorFlow,
   but always before constructing any graphs.
 
-  This function cannot address all possible sources of non-determinism.  Please
+  This function cannot address all possible sources of non-determinism. Please
   see further instructions at https://github.com/NVIDIA/tensorflow-determinism
   to understand how to use it in a larger deterministic context.
 
@@ -52,7 +52,7 @@ def _enable_determinism(seed=None):
     _patch_bias_add()
   if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('2.1'):
     os.environ['TF_DETERMINISTIC_OPS'] = '1'
+    # TODO: Also apply the fused softmax/cross-entropy patch here? The issue seems to still be present in TF 2.1 and 2.2
   if in_ngc_cont and ngc_vers.at_least('19.06') or tf_vers.at_least('1.14'):
-    # Apply the fused softmax/cross-entropy patch here
-    pass
-    # TODO: Add other recipe items
+    _patch_fused_softmax_cross_entropy()
+    # TODO: Add other recipe items
diff --git a/tfdeterminism/patch.py b/tfdeterminism/patch.py
index d4d3118..7da63c0 100644
--- a/tfdeterminism/patch.py
+++ b/tfdeterminism/patch.py
@@ -70,7 +70,12 @@ def _patch():
   if re.match("(1\.(14|15)|2\.0)", tf_version):
     os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
     _patch_bias_add()
-    # Apply the fused softmax/cross-entropy patch here
+    _patch_fused_softmax_cross_entropy()
+    print("TensorFlow version %s has been patched "
+          "using tfdeterminism version %s" %
+          (tf_version, __version__), file=sys.stderr)
+  elif re.match("2\.1|2\.2", tf_version):
+    _patch_fused_softmax_cross_entropy()
     print("TensorFlow version %s has been patched "
           "using tfdeterminism version %s" %
           (tf_version, __version__), file=sys.stderr)
@@ -78,6 +83,7 @@
     raise TypeError("tfdeterminism: No patch available "
                     "for version %s of TensorFlow" % tf_version)
 
+
 def _patch_bias_add():
   tf.nn.bias_add = _new_bias_add_1_14 # access via public API
   nn.bias_add = _new_bias_add_1_14 # called from tf.keras.layers.convolutional.Conv
@@ -136,3 +142,109 @@
         value, array_ops.reshape(bias, broadcast_shape), name=name)
   else: # data_format == 'NHWC' or data_format == None
     return math_ops.add(value, bias, name=name)
+
+
+def _patch_fused_softmax_cross_entropy():
+  # Non-sparse
+  tf.nn.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # access via public API
+  nn.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # called from elsewhere in the TensorFlow python package
+  nn_ops.softmax_cross_entropy_with_logits = _new_softmax_cross_entropy_with_logits # called from tests
+
+  # Sparse
+  tf.nn.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # access via public API
+  nn.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # called from elsewhere in the TensorFlow python package
+  nn_ops.sparse_softmax_cross_entropy_with_logits = _new_sparse_softmax_cross_entropy_with_logits # called from tests
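+
+
+# Note on the approach used by both replacements below: the fused
+# softmax/cross-entropy GPU kernel is the source of the non-determinism, so
+# the replacements compose the loss from unfused ops via the identity
+#
+#   cross_entropy = -sum(labels * log_softmax(logits))  # along the class dim
+#
+# Assumption (TODO: verify): the unfused ops, and the gradients that
+# backprop derives from them, run deterministically on GPU.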
+
+
+# The original, pre-patched method can be viewed at
+# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L3182
+def _new_softmax_cross_entropy_with_logits(
+    _sentinel=None,  # pylint: disable=invalid-name
+    labels=None,
+    logits=None,
+    dim=-1,
+    name=None,
+    axis=None):
+  """Computes softmax cross entropy between `logits` and `labels`.
+
+  Measures the probability error in discrete classification tasks in which the
+  classes are mutually exclusive (each entry is in exactly one class). For
+  example, each CIFAR-10 image is labeled with one and only one label: an image
+  can be a dog or a truck, but not both.
+
+  **NOTE:** While the classes are mutually exclusive, their probabilities
+  need not be. All that is required is that each row of `labels` is
+  a valid probability distribution. If they are not, the computation of the
+  gradient will be incorrect.
+
+  If using exclusive `labels` (wherein one and only
+  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
+
+  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+  on `logits` internally for efficiency. Do not call this op with the
+  output of `softmax`, as it will produce incorrect results.
+
+  A common use case is to have logits and labels of shape
+  `[batch_size, num_classes]`, but higher dimensions are supported, with
+  the `dim` argument specifying the class dimension.
+
+  Backpropagation will happen only into `logits`. To calculate a cross entropy
+  loss that allows backpropagation into both `logits` and `labels`, see
+  `tf.nn.softmax_cross_entropy_with_logits_v2`.
+
+  **Note that to avoid confusion, it is required to pass only named arguments to
+  this function.**
+
+  Args:
+    _sentinel: Used to prevent positional parameters. Internal, do not use.
+    labels: Each vector along the class dimension should hold a valid
+      probability distribution e.g. for the case in which labels are of shape
+      `[batch_size, num_classes]`, each row of `labels[i]` must be a valid
+      probability distribution.
+    logits: Per-label activations, typically a linear output. These activation
+      energies are interpreted as unnormalized log probabilities.
+    dim: The class dimension. Defaulted to -1 which is the last dimension.
+    name: A name for the operation (optional).
+    axis: Alias for dim.
+
+  Returns:
+    A `Tensor` that contains the softmax cross entropy loss. Its type is the
+    same as `logits` and its shape is the same as `labels` except that it does
+    not have the last dimension of `labels`.
+  """
+  # Unfused reimplementation sketch; assumes (TODO: verify) that the unfused
+  # ops and their autodiff gradients are deterministic on GPU.
+  if _sentinel is not None:
+    raise ValueError("Only call `softmax_cross_entropy_with_logits` with "
+                     "named arguments (labels=..., logits=..., ...)")
+  if axis is not None:
+    dim = axis
+  labels = array_ops.stop_gradient(labels)  # match v1: no backprop into labels
+  return math_ops.negative(
+      math_ops.reduce_sum(labels * nn_ops.log_softmax(logits, dim), axis=dim),
+      name=name)
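+
+
+# A quick sanity check of the replacement above (hypothetical snippet, for
+# illustration only):
+#
+#   labels = tf.constant([[0.0, 1.0]])
+#   logits = tf.constant([[0.0, 0.0]])
+#   tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
+#   # -> [0.6931...], i.e. -log(0.5), matching the fused kernel's result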
+
+
+# The original, pre-patched method can be viewed at
+# https://github.com/tensorflow/tensorflow/blob/v1.14.0/tensorflow/python/ops/nn_ops.py#L2628
+def _new_sparse_softmax_cross_entropy_with_logits(
+    _sentinel=None,  # pylint: disable=invalid-name
+    labels=None,
+    logits=None,
+    name=None):
+  """Computes sparse softmax cross entropy between `logits` and `labels`.
+
+  Measures the probability error in discrete classification tasks in which the
+  classes are mutually exclusive (each entry is in exactly one class). For
+  example, each CIFAR-10 image is labeled with one and only one label: an image
+  can be a dog or a truck, but not both.
+
+  **NOTE:** For this operation, the probability of a given label is considered
+  exclusive. That is, soft classes are not allowed, and the `labels` vector
+  must provide a single specific index for the true class for each row of
+  `logits` (each minibatch entry). For soft softmax classification with
+  a probability distribution for each entry, see
+  `softmax_cross_entropy_with_logits_v2`.
+
+  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+  on `logits` internally for efficiency. Do not call this op with the
+  output of `softmax`, as it will produce incorrect results.
+
+  A common use case is to have logits of shape
+  `[batch_size, num_classes]` and have labels of shape
+  `[batch_size]`, but higher dimensions are supported, in which
+  case the `dim`-th dimension is assumed to be of size `num_classes`.
+  `logits` must have the dtype of `float16`, `float32`, or `float64`, and
+  `labels` must have the dtype of `int32` or `int64`.
+
+  **Note that to avoid confusion, it is required to pass only named arguments to
+  this function.**
+
+  Args:
+    _sentinel: Used to prevent positional parameters. Internal, do not use.
+    labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-1}]` (where `r` is rank of
+      `labels` and result) and dtype `int32` or `int64`. Each entry in `labels`
+      must be an index in `[0, num_classes)`. Other values will raise an
+      exception when this op is run on CPU, and return `NaN` for corresponding
+      loss and gradient rows on GPU.
+    logits: Per-label activations (typically a linear output) of shape
+      `[d_0, d_1, ..., d_{r-1}, num_classes]` and dtype `float16`, `float32`, or
+      `float64`. These activation energies are interpreted as unnormalized log
+      probabilities.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `Tensor` of the same shape as `labels` and of the same type as `logits`
+    with the softmax cross entropy loss.
+
+  Raises:
+    ValueError: If logits are scalars (need to have rank >= 1) or if the rank
+      of the labels is not equal to the rank of the logits minus one.
+  """
+  # Unfused reimplementation sketch: expand the sparse labels to dense one-hot
+  # vectors and reuse the unfused dense implementation above. Note that this
+  # sketch does not reproduce the fused op's exact NaN/exception behavior for
+  # out-of-range labels.
+  if _sentinel is not None:
+    raise ValueError("Only call `sparse_softmax_cross_entropy_with_logits` "
+                     "with named arguments (labels=..., logits=..., ...)")
+  logits = ops.convert_to_tensor(logits, name="logits")
+  num_classes = array_ops.shape(logits)[-1]
+  dense_labels = array_ops.one_hot(labels, num_classes, dtype=logits.dtype)
+  return _new_softmax_cross_entropy_with_logits(
+      labels=dense_labels, logits=logits, name=name)
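+
+
+# Example end-to-end usage (a sketch; assumes the replacements above):
+#
+#   from tfdeterminism import patch
+#   patch()  # rebinds the fused softmax/cross-entropy entry points
+#   # ...build the model; tf.nn.softmax_cross_entropy_with_logits and its
+#   # sparse variant now run through the unfused implementations.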