Follow

Keep Up to Date with the Most Important News

By pressing the Subscribe button, you confirm that you have read and are agreeing to our Privacy Policy and Terms of Use
Contact

How to send multiple arguments to a function when using imap?

I want to use pool.imap to run the following code with multiprocessing:


# RSI
def predict(stock_symbol,date_confidences):
    """Label RSI rebounds for one stock and append one tuple per row to ``date_confidences``.

    Reads ``data/{stock_symbol}`` as CSV; the code relies on its ``RSI`` and
    ``Volume`` columns. A row gets ``predict = 1`` when its RSI is below 30 and
    higher than the previous row's RSI, otherwise 0.

    Args:
        stock_symbol: File name (under ``data/``) of the CSV to read.
        date_confidences: Mapping whose values support ``+=`` with a list; each
            row appends ``(1, predict, stock_symbol, None, Volume)``.

    Returns:
        None. Results are written into ``date_confidences``.
    """
    stock_data = pd.read_csv(f'data/{stock_symbol}')


    # Calculate the RSI change from the previous row (NaN for the first row)
    stock_data['rsi_change'] = stock_data['RSI'].diff()

    # Initialize the 'predict' column with zeros
    stock_data['predict'] = 0

    # Set 'predict' to 1 for rows where RSI is below 30 and has increased from the previous row
    stock_data.loc[(stock_data['RSI'] < 30) & (stock_data['rsi_change'] > 0), 'predict'] = 1

    # Drop the helper 'rsi_change' column; it is not used past this point
    stock_data.drop(columns=['rsi_change'], inplace=True)

    for ind, row in stock_data.iterrows():
        # NOTE(review): `date` is neither a parameter nor a local of this function, so it
        # resolves to a module-level name. Presumably the row's own date was intended -- confirm.
        date_confidences[date] += [(1,row['predict'],stock_symbol,None,row['Volume'])]


if __name__ == '__main__':
    # Symbols are the CSV file names under ../../data with the directory and '.csv' removed.
    stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]

    # Build every calendar day from 1 Jan of the first year through 31 Dec of the
    # last year, formatted as 'YYYY-MM-DD'.
    years = [1962,2023]
    current_date = datetime(years[0], 1, 1)
    date_list = []
    while current_date.year <= years[-1]:
        date_list.append(current_date.strftime('%Y-%m-%d'))
        current_date += timedelta(days=1)

    # date_confidences = collections.defaultdict(list)
    # Manager-backed dict (presumably multiprocessing.Manager; the import is not shown)
    # so that every worker process writes into the same mapping.
    manager = Manager()
    date_confidences = manager.dict()

    # Start every date with an empty list so workers can append to it.
    for date in date_list:
        date_confidences[date] = []

    with Pool(processes=os.cpu_count()) as pool:
        # Each work item is a single (stock_symbol, date_confidences) tuple.
        args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
        # imap hands each item to predict as ONE positional argument, so the tuple is not
        # unpacked -- this is what produces the "missing 1 required positional argument"
        # TypeError shown in the traceback. list(...) drains the iterator so every task
        # runs and worker exceptions are re-raised here.
        list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))

    # NOTE(review): date_confidences is a manager proxy at this point, not a plain dict --
    # confirm the pickle contains the data rather than just a reference to the manager.
    with open('cache/date_confidences.pkl','wb') as file:
        pickle.dump(date_confidences,file)

However, I get this error:

multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
TypeError: predict() missing 1 required positional argument: 'date_confidences'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/disk/nouri/stock/chatgpt/crossvalidation/day_estimator/indicators/analyze.py", line 102, in <module>
    list(tqdm(pool.imap(predict, args_list), total=len(stock_symbols)))
  File "/disk/nouri/environment/transformer3/lib/python3.9/site-packages/tqdm/std.py", line 1195, in __iter__
    for obj in iterable:
  File "/disk/nouri/environment/transformer3/lib/python3.9/multiprocessing/pool.py", line 870, in next
    raise value
TypeError: predict() missing 1 required positional argument: 'date_confidences'

In my code, I’m passing both arguments, but I still get the missing-argument error. I need to pass both arguments to the function predict. Is it possible to send these two arguments with imap, or should I use another method?

MEDevel.com: Open-source for Healthcare and Education

Collecting and validating open-source software for healthcare, education, enterprise, development, medical imaging, medical records, and digital pathology.

Visit Medevel

>Solution :

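`imap` passes each item of the iterable to the function as a single argument, so the tuple is never unpacked. You can define a new function that takes one single argument, unpacks it inside the function, and passes the pieces on to `predict`: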

# RSI
def predict(stock_symbol,date_confidences):
    """Label RSI rebounds for one stock and append one tuple per row to ``date_confidences``.

    Reads ``data/{stock_symbol}`` as CSV; the code relies on its ``RSI`` and
    ``Volume`` columns. A row gets ``predict = 1`` when its RSI is below 30 and
    higher than the previous row's RSI, otherwise 0.

    Args:
        stock_symbol: File name (under ``data/``) of the CSV to read.
        date_confidences: Mapping whose values support ``+=`` with a list; each
            row appends ``(1, predict, stock_symbol, None, Volume)``.

    Returns:
        None. Results are written into ``date_confidences``.
    """
    stock_data = pd.read_csv(f'data/{stock_symbol}')

    # Calculate the RSI change from the previous row (NaN for the first row)
    stock_data['rsi_change'] = stock_data['RSI'].diff()

    # Initialize the 'predict' column with zeros
    stock_data['predict'] = 0

    # Set 'predict' to 1 for rows where RSI is below 30 and has increased from the previous row
    stock_data.loc[(stock_data['RSI'] < 30) & (stock_data['rsi_change'] > 0), 'predict'] = 1

    # Drop the helper 'rsi_change' column; it is not used past this point
    stock_data.drop(columns=['rsi_change'], inplace=True)

    for ind, row in stock_data.iterrows():
        # NOTE(review): `date` is neither a parameter nor a local of this function, so it
        # resolves to a module-level name. Presumably the row's own date was intended -- confirm.
        date_confidences[date] += [(1,row['predict'],stock_symbol,None,row['Volume'])]

def process_stock(args):
    """Adapter for ``Pool.imap``: unpack one work item and forward it to ``predict``.

    Args:
        args: A two-item iterable ``(stock_symbol, date_confidences)``.

    Returns:
        Whatever ``predict`` returns for that pair.
    """
    # imap delivers each work item as a single object, so split it here.
    symbol, shared_confidences = args
    return predict(stock_symbol=symbol, date_confidences=shared_confidences)

if __name__ == '__main__':
    # Symbols are the CSV file names under ../../data with the directory and '.csv' removed.
    stock_symbols = [s.split('/')[-1].replace('.csv','') for s in sorted(glob.glob('../../data/*.csv'))]

    # Build every calendar day from 1 Jan of the first year through 31 Dec of the
    # last year, formatted as 'YYYY-MM-DD'.
    years = [1962,2023]
    current_date = datetime(years[0], 1, 1)
    date_list = []
    while current_date.year <= years[-1]:
        date_list.append(current_date.strftime('%Y-%m-%d'))
        current_date += timedelta(days=1)

    # Manager-backed dict so that every worker process writes into the same mapping.
    manager = Manager()
    date_confidences = manager.dict()

    # Start every date with an empty list so workers can append to it.
    for date in date_list:
        date_confidences[date] = []

    with Pool(processes=os.cpu_count()) as pool:
        # Each work item is a single (stock_symbol, date_confidences) tuple;
        # process_stock unpacks it because imap passes exactly one argument per call.
        args_list = [(stock_symbol, date_confidences) for stock_symbol in stock_symbols]
        # list(...) drains the lazy imap iterator so every task runs and any
        # worker exception is re-raised here.
        list(tqdm(pool.imap(process_stock, args_list), total=len(stock_symbols)))

    # Pickling the proxy itself would store only a reference to the manager
    # process, which no longer exists when the file is loaded later. Copy the
    # contents into a plain dict first so the pickle holds the actual data.
    with open('cache/date_confidences.pkl','wb') as file:
        pickle.dump(date_confidences.copy(),file)
Add a comment

Leave a Reply

Keep Up to Date with the Most Important News

By pressing the Subscribe button, you confirm that you have read and are agreeing to our Privacy Policy and Terms of Use

Discover more from Dev solutions

Subscribe now to keep reading and get access to the full archive.

Continue reading