I have simple function that I am trying to convert into generator that return batch output. Here is code
# Simple function, working but too slow once get big list
def compute_add():
data = range(5)
cases = list(itertools.permutations(data, 2))
print(f"{cases=}")
result = []
for x, y in cases:
ans = x + y
result.append(ans)
return result
report = compute_add()
print(f"{report=}")
So in first attempt I tried it to convert into generator that returns single output
# Generator single value and working
def compute_add_generator():
data = range(5)
cases = list(itertools.permutations(data, 2))
print(f"{cases=}")
for x, y in cases:
ans = x + y
yield ans
report = []
for res in compute_add_generator():
report.append(res)
print(f"{report=}")
In the second attempt I tried it to convert to a generator that returns batch output(list)
# Generator batch output and not working
def compute_add_generator_batch(batch_size):
data = range(5)
cases = list(itertools.permutations(data, 2))
print(f"{cases=}")
res = []
for x, y in cases:
ans = x + y
if len(res) != batch_size:
res.append(ans)
continue
yield res
res = []
batch_size=3 # it can vary
for res in compute_add_generator_batch(batch_size):
print(f"{res=}")
It is giving me wrong output as it always skiping the 4th output
cases=[(0, 1), (0, 2), (0, 3), (0, 4), (1, 0), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1), (2, 3), (2, 4), (3, 0), (3, 1), (3, 2), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3)]
res=[1, 2, 3]
res=[1, 3, 4]
res=[2, 3, 5]
res=[3, 4, 5]
res=[4, 5, 6]
Expected output
res=[1, 2, 3]
res=[4, 1, 3]
res=[4, 5, 2]
res=[3, 5, 6]
res=[3, 4, 5]
res=[7, 4, 5]
res=[6, 7]
I tried storing 4th result in res = [ans] still wrong output as misses last row.
What I am missing?
>Solution :
Here is an version that supports the batch_size parameter:
import itertools
def compute_add_generator(batch_size):
assert batch_size > 0
data = range(5)
batch = []
for x, y in itertools.permutations(data, 2):
ans = x + y
batch.append(ans)
if len(batch) == batch_size:
yield batch
batch = []
# the rest:
if batch:
yield batch
report = []
for res in compute_add_generator(3):
report.append(res)
print(f"{report=}")
Prints:
report = [[1, 2, 3], [4, 1, 3], [4, 5, 2], [3, 5, 6], [3, 4, 5], [7, 4, 5], [6, 7]]