tensorflow: RaggedTensor casting bug
Version 2.0.0-alpha. Nested RaggedTensors appear to be cast to int64 for no apparent reason.
With regular tensors (ok):
>>> tf.constant([[1]], dtype=tf.int8)
<tf.Tensor: id=98, shape=(1, 1), dtype=int8, numpy=array([[1]], dtype=int8)>
With a nested RaggedTensor (not ok):
>>> tf.ragged.constant([[1]], dtype=tf.int8)
tf.RaggedTensor(values=tf.Tensor([1], shape=(1,), dtype=int8), row_splits=tf.Tensor([0 1], shape=(2,), dtype=int64))
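A note for reading the output above: the printed RaggedTensor has two parts, a flat values tensor (dtype=int8, as requested) and a row_splits tensor (dtype=int64) holding the row boundaries. This suggests the int64 belongs to the index tensor rather than to the data. The pure-Python sketch below mimics that layout for a single ragged dimension; encode_ragged is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def encode_ragged(rows):
    """Flatten a list of rows into (values, row_splits), mimicking the
    layout printed for a RaggedTensor with one ragged dimension."""
    values = []
    row_splits = [0]  # row i spans values[row_splits[i]:row_splits[i + 1]]
    for row in rows:
        values.extend(row)
        row_splits.append(len(values))
    return values, row_splits


values, row_splits = encode_ragged([[1]])
print(values, row_splits)  # [1] [0, 1]
```

The row boundaries are always integer offsets, whatever the element type of the values is, so the two dtypes can differ.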
RaggedTensors also cannot be yielded from a generator passed to tf.data.Dataset.from_generator, even with dtype=tf.int64. The code below fails with:
The expected type was int64, but the yielded element was <tf.RaggedTensor [[6]]>.
import tensorflow as tf

class LineGenerator(object):
    def get_next_line(self):
        # Endlessly yield the same one-row ragged tensor.
        while True:
            out = [[6]]
            yield tf.ragged.constant(out, dtype=tf.int64)

class Dataset(object):
    def __init__(self, generator=LineGenerator()):
        self.next_element = self.build_iterator(generator)

    def build_iterator(self, gen: LineGenerator):
        # Iterating this dataset produces the error quoted above.
        dataset = tf.data.Dataset.from_generator(gen.get_next_line, output_types=tf.int64)
        # some other code...
About this issue
- State: closed
- Created 5 years ago
- Comments: 19 (11 by maintainers)
@stekiri This is a bug with Dataset.from_generator. #37400 is a PR to fix it, though I think it’s been stalled for a little while on making sure it doesn’t break a test.
tf.data support for RaggedTensors was added by 5fe90dc.
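For readers on a newer release, here is a minimal sketch of yielding RaggedTensors from a generator. It assumes a TensorFlow version in which from_generator accepts an output_signature argument and tf.RaggedTensorSpec is available (2.4 or later, to the best of my knowledge); it is not necessarily the change made in the commit or PR mentioned above. The name line_generator and the sample rows are made up for illustration.

```python
import tensorflow as tf


def line_generator():
    # Yield a couple of ragged elements and stop, so the sketch terminates.
    for out in ([[6]], [[1, 2], [3]]):
        yield tf.ragged.constant(out, dtype=tf.int64)


# Describe each yielded element as a 2-D ragged tensor of int64 values.
spec = tf.RaggedTensorSpec(shape=(None, None), dtype=tf.int64)
dataset = tf.data.Dataset.from_generator(line_generator, output_signature=spec)

# Convert each element back to nested Python lists for inspection (eager mode).
rows = [element.to_list() for element in dataset]
print(rows)
```

The difference from the failing code is that output_types=tf.int64 describes a dense tensor, while a RaggedTensorSpec describes the ragged structure explicitly.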
I know. I’m not yet complete. I’m testing out various ways. Will create a PR as soon as I’m done.
On Sun, 7 Apr 2019, 7:18 pm, ARozental wrote:
On second thought, this issue can be solved by modifying the from_generator() function. @dynamicwebpaige @alextp can I work on it?