if isinstance(res, cls._tensor_type):
return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
+ elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
+ # share the evaluation between lazy tuple elements
+ shared_args: list = [args, None]
+
+ def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
+ assert len(a) == 2
+ if a[1] is None:
+ a[1] = fn(*a[0], **kw)
+ return a[1][i]
+ return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res)))
else:
del res # not needed
# non-tensor return likely relies on the contents of the args