NAME

AI::MXNet::RNN - Functions for constructing recurrent neural networks.

SYNOPSIS

DESCRIPTION

Functions for constructing recurrent neural networks.

save_rnn_checkpoint

Save a checkpoint for a model that uses RNN cells.
The weights used by the cells are unpacked before saving.
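Here is a rough illustration of what unpacking means. The sketch assumes a hypothetical LSTM cell that stores the input-to-hidden weights of its four gates (input, forget, cell, output) fused in one matrix, with the rows stacked gate by gate. It splits that matrix into one named entry per gate. Plain array refs stand in for AI::MXNet::NDArray objects. The helper unpack_gate_weights and the parameter names are assumptions made for illustration and are not this module's code.

```perl
use strict;
use warnings;

# Hypothetical gate order for an LSTM cell: input, forget, cell, output.
my @GATES = qw(i f c o);

# Split a fused weight matrix (array ref of rows, gates stacked in order)
# into one entry per gate. Array refs stand in for NDArrays here.
sub unpack_gate_weights {
    my ($params, $name) = @_;            # e.g. $name = 'lstm_i2h'
    my %out   = %$params;
    my $fused = delete $out{"${name}_weight"};
    my $rows  = @$fused / @GATES;        # rows belonging to each gate
    die "fused weight rows not divisible by gate count\n"
        unless $rows == int $rows;
    for my $g (0 .. $#GATES) {
        my @slice = @{$fused}[ $g * $rows .. ($g + 1) * $rows - 1 ];
        $out{"${name}_$GATES[$g]_weight"} = \@slice;
    }
    return \%out;
}

# Fused weight for a single hidden unit: four rows, one per gate.
my $packed   = { lstm_i2h_weight => [ [1, 2], [3, 4], [5, 6], [7, 8] ] };
my $unpacked = unpack_gate_weights($packed, 'lstm_i2h');
print join(',', sort keys %$unpacked), "\n";
# prints: lstm_i2h_c_weight,lstm_i2h_f_weight,lstm_i2h_i_weight,lstm_i2h_o_weight
```

Saving per-gate entries keeps the checkpoint independent of how a particular cell implementation fuses its gates.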

Parameters
----------
cells : AI::MXNet::RNN::Cell or array ref of AI::MXNet::RNN::Cell
    The RNN cells used by this symbol.
prefix : str
    Prefix of model name.
epoch : int
    The epoch number of the model.
symbol : Symbol
    The input symbol.
arg_params : hash ref of str to AI::MXNet::NDArray
    Model parameters: hash ref mapping names to NDArrays of the network's weights.
aux_params : hash ref of str to AI::MXNet::NDArray
    Model parameters: hash ref mapping names to NDArrays of the network's auxiliary states.

Notes
-----
- The symbol is saved to prefix-symbol.json.
- The parameters are saved to prefix-epoch.params.
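The two file names follow from prefix and epoch. The sketch below assumes MXNet's usual convention of zero-padding the epoch to four digits (for example lstm_lm-0003.params), as the Python MXNet package does. Check this against the files your installation actually writes. checkpoint_files is a hypothetical helper and is not part of AI::MXNet::RNN.

```perl
use strict;
use warnings;

# Hypothetical helper: return the symbol and parameter file names that a
# checkpoint with the given prefix and epoch is expected to use.
# Assumes the epoch is zero-padded to four digits, as in MXNet.
sub checkpoint_files {
    my ($prefix, $epoch) = @_;
    return (
        "$prefix-symbol.json",
        sprintf('%s-%04d.params', $prefix, $epoch),
    );
}

my ($symbol_file, $params_file) = checkpoint_files('lstm_lm', 3);
print "$symbol_file\n$params_file\n";
# prints:
# lstm_lm-symbol.json
# lstm_lm-0003.params
```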

load_rnn_checkpoint

Load a model checkpoint from file.
The weights are packed for the given cells after loading.
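Packing is the inverse of unpacking. This sketch assumes a hypothetical LSTM cell with four gates in the order input, forget, cell, output. Plain array refs stand in for NDArrays. It concatenates the per-gate entries read from a checkpoint back into one fused weight. pack_gate_weights and the parameter names are hypothetical and are not part of AI::MXNet::RNN.

```perl
use strict;
use warnings;

# Hypothetical gate order; the fused matrix stacks gates in this order.
my @GATES = qw(i f c o);

# Concatenate per-gate weights (array refs of rows) into one fused matrix
# and remove the per-gate entries. Array refs stand in for NDArrays.
sub pack_gate_weights {
    my ($params, $name) = @_;            # e.g. $name = 'lstm_i2h'
    my %out = %$params;
    my @fused;
    for my $gate (@GATES) {
        my $part = delete $out{"${name}_${gate}_weight"}
            // die "missing weight for gate '$gate'\n";
        push @fused, @$part;
    }
    $out{"${name}_weight"} = \@fused;
    return \%out;
}

my $loaded = {
    lstm_i2h_i_weight => [ [1, 2] ],
    lstm_i2h_f_weight => [ [3, 4] ],
    lstm_i2h_c_weight => [ [5, 6] ],
    lstm_i2h_o_weight => [ [7, 8] ],
};
my $packed = pack_gate_weights($loaded, 'lstm_i2h');
print join(' ', map { "[@$_]" } @{ $packed->{lstm_i2h_weight} }), "\n";
# prints: [1 2] [3 4] [5 6] [7 8]
```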

Parameters
----------
cells : AI::MXNet::RNN::Cell or array ref of AI::MXNet::RNN::Cell
    The RNN cells used by this symbol.
prefix : str
    Prefix of model name.
epoch : int
    The epoch number of the model to load.

Returns
-------
symbol : Symbol
    The symbol configuration of the computation network.
arg_params : hash ref of str to NDArray
    Model parameters: hash ref mapping names to NDArrays of the network's weights.
aux_params : hash ref of str to NDArray
    Model parameters: hash ref mapping names to NDArrays of the network's auxiliary states.

Notes
-----
- The symbol is loaded from prefix-symbol.json.
- The parameters are loaded from prefix-epoch.params.

do_rnn_checkpoint

Create a callback that checkpoints a Module to prefix every period epochs (every epoch by default).
The weights used by the cells are unpacked before saving.

Parameters
----------
cells : subclass of AI::MXNet::RNN::Cell
    The RNN cells used by this module.
prefix : str
    The file prefix to write checkpoints to.
period : int
    How many epochs to wait before checkpointing. Default is 1.

Returns
-------
callback : code ref
    A callback that can be passed as epoch_end_callback to fit.
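The period logic can be sketched in plain Perl. The closure below assumes the callback receives a zero-based epoch index, followed by the symbol and parameters. It saves after every period-th completed epoch and labels the checkpoint with the one-based epoch count. make_checkpoint_callback and the $save code ref are hypothetical stand-ins; in AI::MXNet, the saving step would be save_rnn_checkpoint.

```perl
use strict;
use warnings;

# Hypothetical stand-in for do_rnn_checkpoint: returns a closure that calls
# $save every $period epochs. $save receives the one-based epoch number
# followed by whatever else the caller passed (symbol, params, ...).
sub make_checkpoint_callback {
    my ($save, $period) = @_;
    $period = 1 unless defined $period && $period >= 1;
    return sub {
        my ($epoch_index, @rest) = @_;   # zero-based epoch index
        my $epoch = $epoch_index + 1;    # epochs completed so far
        $save->($epoch, @rest) if $epoch % $period == 0;
    };
}

my @saved;
my $callback = make_checkpoint_callback(sub { push @saved, $_[0] }, 3);
$callback->($_) for 0 .. 9;              # simulate ten epochs
print "@saved\n";
# prints: 3 6 9
```

Returning a closure keeps cells, prefix and period bound at construction time, so the training loop only needs to pass per-epoch data.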