Description
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
tf_nightly 2.21.0-dev20260112 (v1.12.1-135334-g20c4833e3b8)
Custom code
Yes
OS platform and distribution
Ubuntu 24.04.3 LTS
Mobile device
No response
Python version
3.11.14
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
12.9
GPU model and memory
NVIDIA GeForce RTX 5090
Current behavior?
When tf.slice is called with a zero-sized size argument on a dynamically shaped tensor inside a jit_compile=True function:
- Eager mode succeeds and returns an empty tensor of shape (0, 1), which matches the expected TensorFlow semantics.
- TF-XLA mode (@tf.function(jit_compile=True)) crashes with a heap-corruption error (free(): invalid next size (fast)) and the process aborts.
Tested on CPU only, with the command: CUDA_VISIBLE_DEVICES="" python3 test_code.py
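The eager result follows from the documented tf.slice contract: size[i] == -1 means "all remaining elements of dimension i", and every dimension must satisfy 0 <= begin[i] <= begin[i] + size[i] <= input_shape[i]. The sketch below is a TensorFlow-free illustration of that rule; expected_slice_shape is a hypothetical helper, not a TensorFlow API. It shows that begin=[0, 0], size=[0, 1] is a valid slice however many rows the dynamically shaped input has, including zero rows.

```python
from typing import Sequence, Tuple


def expected_slice_shape(input_shape: Sequence[int],
                         begin: Sequence[int],
                         size: Sequence[int]) -> Tuple[int, ...]:
    """Return the output shape tf.slice is documented to produce.

    Hypothetical helper for illustration only; it does not call TensorFlow.
    """
    if not (len(input_shape) == len(begin) == len(size)):
        raise ValueError("input_shape, begin and size must have the same rank")
    out = []
    for dim, b, s in zip(input_shape, begin, size):
        if s == -1:
            # -1 means: take everything from begin to the end of this dimension.
            s = dim - b
        # Documented constraint: 0 <= begin <= begin + size <= dim.
        if not (0 <= b <= b + s <= dim):
            raise ValueError(f"invalid slice: begin={b}, size={s}, dim={dim}")
        out.append(s)
    return tuple(out)


# In the reproducer, 19 of the 20 input elements are nonzero, so tf.where
# yields 19 rows of 3 indices and top_k(k=2) leaves a (19, 2) tensor.
print(expected_slice_shape((19, 2), (0, 0), (0, 1)))
# Zero rows (e.g. an all-zero input) still gives a valid, empty slice.
print(expected_slice_shape((0, 2), (0, 0), (0, 1)))
```

In both cases the expected shape is (0, 1), so a compiled path that crashes instead of returning an empty tensor is diverging from the op's contract rather than rejecting an invalid slice.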
Standalone code to reproduce the issue
import numpy as np
import tensorflow as tf

class Model(tf.keras.Model):
    @tf.function(jit_compile=True)
    def call(self, inputs, training):
        # tf.where produces dynamic shape (None, 3)
        indices = tf.where(inputs)
        # top_k on dynamic shape -> (None, 2)
        values, _ = tf.math.top_k(indices, k=2, sorted=False)
        # slice with zero size on dynamic shape -> (0, 1)
        return tf.slice(values, tf.constant([0, 0]), tf.constant([0, 1]))

m = Model()
inp = np.arange(0, 20).reshape([2, 2, 5]).astype(np.int64)
y = m(inp, training=False)
Relevant log output
free(): invalid next size (fast)
Aborted (core dumped)
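glibc prints "free(): invalid next size (fast)" when it detects heap corruption and then calls abort(), so the whole interpreter dies with SIGABRT (reported by the shell as "Aborted (core dumped)"). That also kills any test harness running in the same process. A minimal sketch, assuming a POSIX system, of running a reproducer in a fresh child interpreter and reporting how it exited; run_isolated is a hypothetical helper, not part of TensorFlow.

```python
import signal
import subprocess
import sys


def run_isolated(source: str, timeout: float = 600.0) -> str:
    """Run `source` in a separate Python interpreter and describe its exit.

    Hypothetical triage helper: a native abort in the child cannot take
    down the calling process.
    """
    proc = subprocess.run([sys.executable, "-c", source],
                          capture_output=True, text=True, timeout=timeout)
    if proc.returncode < 0:
        # POSIX only: a return code of -N means the child was killed by signal N.
        return f"killed by {signal.Signals(-proc.returncode).name}"
    if proc.returncode != 0:
        return f"exited with status {proc.returncode}"
    return "ok"


# os.abort() raises SIGABRT the same way glibc does after detecting corruption.
print(run_isolated("import os; os.abort()"))
print(run_isolated("pass"))
```

Passing the reproducer's source to run_isolated would be expected to report "killed by SIGABRT" while the bug is present, and "ok" once the jit_compile=True path returns the empty (0, 1) tensor like eager mode.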