Unable to use numpy.dot with numba

I am getting errors when calling numpy.dot inside a numba-compiled function. It appears to be supported (see e.g. "numpy: Faster np.dot/ multiply(element-wise multiplication) when one array is the same"), but this code gives me the following error. It runs fine if I remove the @njit decorator.

Code:

import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]])
    b = np.array([[4, 1], [2, 2]])

    return np.dot(a, b)

print(tst_dot())

Error:

No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
 
 >>> dot(array(int64, 2d, C), array(int64, 2d, C))
 
There are 4 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: Failed in nopython mode pipeline (step: native lowering)
       Failed in nopython mode pipeline (step: nopython frontend)
       No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
        
        >>> dot(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))
        
       There are 4 candidate implementations:
             - Of which 2 did not match due to:
             Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                TypingError: too many positional arguments
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
             - Of which 2 did not match due to:
             Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                LoweringError: Failed in nopython mode pipeline (step: native lowering)
              unsupported dtype for <BLAS function>()
              
              File "venv\lib\site-packages\numba\np\linalg.py", line 817:
                          def codegen(context, builder, sig, args):
                              <source elided>
              
                      return lambda left, right, out: _impl(left, right, out)
                      ^
              
              During: lowering "$10call_function.4 = call $2load_deref.0(left, right, out, func=$2load_deref.0, args=[Var(left, linalg.py:817), Var(right, linalg.py:817), Var(out, linalg.py:817)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (817)
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\errors.py:837
       
       During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
       During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (460)
       
       
       File "venv\lib\site-packages\numba\np\linalg.py", line 460:
           def dot_impl(a, b):
               <source elided>
               out = np.empty((m, n), a.dtype)
               return np.dot(a, b, out)
               ^
       
       During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (582)
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typeinfer.py:1086
      - Of which 2 did not match due to:
      Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: missing a required argument: 'out'
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\tst4.py (164)


File "tst4.py", line 164:
def tst_dot(a, b):
    <source elided>

    return np.dot(a, b)
    ^

I have tried passing out=None as a third argument (even though it is meant to be optional), but it didn't help. I expected the same result as when running without numba.

Solution:

The Numba documentation says:

Basic linear algebra is supported on 1-D and 2-D contiguous arrays of floating-point and complex numbers:

  • numpy.dot()

However, your two arrays contain integers, as the signature in the error message shows:

dot(array(int64, 2d, C), array(int64, 2d, C))

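The traceback also reports "unsupported dtype for <BLAS function>()": Numba implements np.dot through BLAS, which only handles floating-point and complex types. The fix is therefore to give the arrays a floating-point dtype: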

import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]], dtype=np.float32)
    b = np.array([[4, 1], [2, 2]], dtype=np.float32)

    return np.dot(a, b)

print(tst_dot())

Output:

[[4. 1.]
 [2. 2.]]