NAME
AI::MXNet::CudaModule - Interface to the runtime CUDA kernel compilation module.
DESCRIPTION
Interface to the runtime CUDA kernel compilation module.
Compile and run CUDA code from Perl.
In CUDA 7.5, you need to prepend your kernel definitions
with 'extern "C"' to avoid name mangling::
$source = '
extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
y[i] += alpha * x[i];
}
';
$module = mx->rtc->CudaModule($source);
$func = $module->get_kernel("axpy", "const float *x, float *y, float alpha");
$x = mx->nd->ones([10], ctx=>mx->gpu(0));
$y = mx->nd->zeros([10], ctx=>mx->gpu(0));
$func->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
print $y->aspdl;
Starting from CUDA 8.0, you can instead export functions by name.
This also allows you to use templates::
my $source = '
template<typename DType>
__global__ void axpy(const DType *x, DType *y, DType alpha) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
y[i] += alpha * x[i];
}
';
$module = mx->rtc->CudaModule($source, exports=>['axpy<float>', 'axpy<double>']);
$func32 = $module->get_kernel("axpy<float>", "const float *x, float *y, float alpha");
$x = mx->nd->ones([10], dtype=>'float32', ctx=>mx->gpu(0));
$y = mx->nd->zeros([10], dtype=>'float32', ctx=>mx->gpu(0));
$func32->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
print $y->aspdl;
$func64 = $module->get_kernel("axpy<double>", "const double *x, double *y, double alpha");
$x = mx->nd->ones([10], dtype=>'float64', ctx=>mx->gpu(0));
$y = mx->nd->zeros([10], dtype=>'float64', ctx=>mx->gpu(0));
$func64->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
print $y->aspdl;
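Keeping the exported kernel name, its signature and the NDArray dtype in step by hand is easy to get wrong when several template instantiations are used. A minimal sketch in core Perl, assuming the axpy-shaped signature above, that derives the export list and the per-dtype signatures from one table; the `%cpp_type` map and the `kernel_specs` helper are hypothetical and not part of AI::MXNet:

```perl
use strict;
use warnings;

# Hypothetical map from MXNet dtype names to the C++ type used to
# instantiate the templated axpy kernel.
my %cpp_type = (float32 => 'float', float64 => 'double');

# Hypothetical helper: build the exported kernel name and the matching
# axpy-shaped signature for every dtype, so the two cannot disagree.
sub kernel_specs {
    my ($base) = @_;
    my %spec;
    for my $dtype (sort keys %cpp_type) {
        my $t = $cpp_type{$dtype};
        $spec{$dtype} = {
            name      => "$base<$t>",
            signature => "const $t *x, $t *y, $t alpha",
        };
    }
    return \%spec;
}

my $specs   = kernel_specs('axpy');
my @exports = map { $specs->{$_}{name} } sort keys %$specs;
print join(', ', @exports), "\n";    # axpy<float>, axpy<double>
```

With AI::MXNet loaded, `\@exports` could then be passed as `exports` to `mx->rtc->CudaModule`, and `$specs->{float64}` supplies the name and signature for `get_kernel` when the NDArrays use `dtype=>'float64'`.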
Parameters
----------
source : Str
Complete source code.
options : Str|ArrayRef[Str]
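Compiler flags. For example, use "-I/usr/local/cuda/include" to
add CUDA headers to the include path.
exports : Str|ArrayRef[Str]
Names of the kernels to export by name (CUDA 8.0 and later), for
example template instantiations such as 'axpy<float>'.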
get_kernel
Get a CUDA kernel from the compiled module.
Parameters
----------
$name : Str
String name of the kernel.
$signature : Str
Function signature for the kernel. For example, if a kernel is
declared as::
extern "C" __global__ void axpy(const float *x, double *y, int alpha)
then its signature should be::
const float *x, double *y, int alpha
or::
const float *, double *, int
Note that `*` in signature marks an argument as array and
`const` marks an argument as constant (input) array.
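Both spellings of the signature carry the same information, because only the type, `*` and `const` matter. A minimal sketch of how such a string can be split into per-argument flags; `parse_signature` is a hypothetical helper written in core Perl for illustration, assuming single-word C types as in the examples above, and is not AI::MXNet's own parser:

```perl
use strict;
use warnings;

# Hypothetical parser: split a kernel signature on commas and record,
# for each argument, its base type, whether it is an array (`*`) and
# whether it is a constant input (`const`). Argument names are ignored.
sub parse_signature {
    my ($signature) = @_;
    my @args;
    for my $arg (split /,/, $signature) {
        my ($const, $type, $star) =
            $arg =~ /^\s*(const\s+)?(\w+)\s*(\*)?\s*(?:\w+)?\s*$/
            or die "Cannot parse argument '$arg'\n";
        push @args, {
            type     => $type,
            is_array => defined $star  ? 1 : 0,
            is_const => defined $const ? 1 : 0,
        };
    }
    return \@args;
}

my $named   = parse_signature('const float *x, double *y, int alpha');
my $unnamed = parse_signature('const float *, double *, int');
for my $p (@$named) {
    printf "%-6s array=%d const=%d\n", $p->{type}, $p->{is_array}, $p->{is_const};
}
# float  array=1 const=1
# double array=1 const=0
# int    array=0 const=0
```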
Returns
-------
AI::MXNet::CudaKernel
CUDA kernels that can be launched on GPUs.
NAME
AI::MXNet::CudaKernel - A compiled CUDA kernel.
DESCRIPTION
A compiled CUDA kernel that can be launched on a GPU.
Instances should only be created by calling AI::MXNet::CudaModule's get_kernel method.
launch
Launch the CUDA kernel.
Parameters
----------
$args : ArrayRef[AI::MXNet::NDArray|Num]
List of arguments for the kernel. NDArrays are expected for pointer
types (e.g. `float*`, `double*`) while numbers are expected for
non-pointer types (e.g. `int`, `float`).
$ctx : AI::MXNet::Context
The context to launch the kernel on. Must be a GPU context.
$grid_dims : array ref of 3 integers (CudaKernelShape)
Grid dimensions for the CUDA kernel (number of blocks along x, y, z).
$block_dims : array ref of 3 integers (CudaKernelShape)
Block dimensions for the CUDA kernel (threads per block along x, y, z).
$shared_mem=0 : integer, optional
Size, in bytes, of dynamically allocated shared memory. Defaults to 0.
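The examples above launch one block of exactly 10 threads for 10 elements, which is why their `axpy` kernels need no bounds check. For arbitrary lengths, a common CUDA pattern is to fix the block size, round the grid size up, and guard the kernel with `if (i < n)`. A minimal sketch in core Perl, assuming a one-dimensional launch; `launch_dims_1d` and its default of 256 threads per block are illustrative choices, not part of AI::MXNet:

```perl
use strict;
use warnings;

# Hypothetical helper: return ($grid_dims, $block_dims) as array refs
# of 3 integers for a 1-D launch that covers $n elements.
sub launch_dims_1d {
    my ($n, $threads_per_block) = @_;
    $threads_per_block //= 256;
    die "n must be positive\n" unless $n > 0;
    # Ceiling division: enough blocks that blocks * threads >= n.
    my $blocks = int(($n + $threads_per_block - 1) / $threads_per_block);
    return ([$blocks, 1, 1], [$threads_per_block, 1, 1]);
}

# Because blocks * threads can exceed n, the kernel must skip the extra
# threads; n is passed as an additional non-pointer (int) argument.
my $guarded_source = <<'CUDA';
extern "C" __global__ void axpy(const float *x, float *y, float alpha, int n) {
    int i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n) y[i] += alpha * x[i];
}
CUDA

my ($grid, $block) = launch_dims_1d(1000);
print "grid=[@$grid] block=[@$block]\n";    # grid=[4 1 1] block=[256 1 1]
```

With AI::MXNet and a GPU available, this kernel would be fetched with the signature "const float *x, float *y, float alpha, int n" and launched as `$func->launch([$x, $y, 3.0, $n], mx->gpu(0), $grid, $block)`.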