org.apache.sysml.runtime.matrix.data

Class LibMatrixDNN



  • public class LibMatrixDNN
    extends Object
    • Field Detail

      • LOG

        protected static final org.apache.commons.logging.Log LOG
      • DISPLAY_STATISTICS

        public static boolean DISPLAY_STATISTICS
    • Constructor Detail

      • LibMatrixDNN

        public LibMatrixDNN()
    • Method Detail

      • appendStatistics

        public static void appendStatistics(StringBuilder sb)
      • resetStatistics

        public static void resetStatistics()
      • conv2dBackwardData

        public static void conv2dBackwardData(MatrixBlock filter,
                                              MatrixBlock dout,
                                              MatrixBlock outputBlock,
                                              ConvolutionParameters params)
                                       throws DMLRuntimeException
        Computes the backpropagation errors for the previous layer of a convolution operation (the gradient with respect to the conv2d input).
        Parameters:
        filter - filter used in conv2d
        dout - errors from next layer
        outputBlock - output errors
        params - convolution parameters
        Throws:
        DMLRuntimeException - if an error occurs while executing the operation
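        The backward-data computation can be sketched as a naive loop nest. This is a hypothetical illustration, not SystemML's implementation: it uses plain row-major double[] arrays instead of MatrixBlock, passes shape, stride, and padding as int arguments in place of ConvolutionParameters, and assumes each row holds one image flattened in NCHW order (filter is F x C*R*S, dout is N x F*P*Q). The class and method names are illustrative only.

```java
/** Hypothetical sketch: gradient of conv2d with respect to its input. */
final class Conv2dBackwardDataSketch {
    static double[] conv2dBackwardData(double[] filter, double[] dout,
            int N, int C, int H, int W, int F, int R, int S,
            int strideH, int strideW, int padH, int padW) {
        int P = (H + 2 * padH - R) / strideH + 1; // output height
        int Q = (W + 2 * padW - S) / strideW + 1; // output width
        double[] dx = new double[N * C * H * W];
        for (int n = 0; n < N; n++)
            for (int f = 0; f < F; f++)
                for (int p = 0; p < P; p++)
                    for (int q = 0; q < Q; q++) {
                        double g = dout[n * F * P * Q + f * P * Q + p * Q + q];
                        if (g == 0) continue; // nothing to scatter
                        // Scatter g * filter back onto every input pixel this output touched.
                        for (int c = 0; c < C; c++)
                            for (int r = 0; r < R; r++) {
                                int h = p * strideH - padH + r;
                                if (h < 0 || h >= H) continue; // lands in padding
                                for (int s = 0; s < S; s++) {
                                    int w = q * strideW - padW + s;
                                    if (w < 0 || w >= W) continue; // lands in padding
                                    dx[n * C * H * W + c * H * W + h * W + w] +=
                                        g * filter[f * C * R * S + c * R * S + r * S + s];
                                }
                            }
                    }
        return dx;
    }
}
```

        With a 3x3 input, an all-ones 2x2 filter, stride 1, and no padding, an all-ones dout yields how many output windows cover each input pixel (1 at the corners, 4 in the center).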
      • maxpoolingBackward

        public static void maxpoolingBackward(MatrixBlock input,
                                              MatrixBlock dout,
                                              MatrixBlock outputBlock,
                                              ConvolutionParameters params,
                                              boolean performReluBackward)
                                       throws DMLRuntimeException
        Computes the backpropagation errors for the previous layer of a max-pooling operation.
        Parameters:
        input - input matrix
        dout - errors from next layer
        outputBlock - output matrix
        params - convolution parameters
        performReluBackward - if true, also perform the ReLU backward operation
        Throws:
        DMLRuntimeException - if an error occurs while executing the operation
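        In max-pooling backward, each upstream error is routed only to the input position that produced the maximum of its pooling window. The sketch below is hypothetical: it uses row-major double[] arrays in NCHW order instead of MatrixBlock, assumes no padding, sends ties to the first maximum in row-major window order, and does not model the performReluBackward flag. Names are illustrative.

```java
/** Hypothetical sketch: max-pooling backward pass without padding. */
final class MaxPoolBackwardSketch {
    static double[] maxpoolingBackward(double[] input, double[] dout,
            int N, int C, int H, int W, int R, int S, int strideH, int strideW) {
        int P = (H - R) / strideH + 1; // pooled height
        int Q = (W - S) / strideW + 1; // pooled width
        double[] dx = new double[N * C * H * W];
        for (int n = 0; n < N; n++)
            for (int c = 0; c < C; c++) {
                int base = n * C * H * W + c * H * W; // start of this image/channel plane
                for (int p = 0; p < P; p++)
                    for (int q = 0; q < Q; q++) {
                        // Locate the argmax of the window (first one wins on ties).
                        int maxIdx = -1;
                        double maxVal = Double.NEGATIVE_INFINITY;
                        for (int r = 0; r < R; r++)
                            for (int s = 0; s < S; s++) {
                                int idx = base + (p * strideH + r) * W + (q * strideW + s);
                                if (maxIdx < 0 || input[idx] > maxVal) {
                                    maxVal = input[idx];
                                    maxIdx = idx;
                                }
                            }
                        // Route the whole upstream error to that single position.
                        dx[maxIdx] += dout[n * C * P * Q + c * P * Q + p * Q + q];
                    }
            }
        return dx;
    }
}
```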
      • reluBackward

        public static void reluBackward(MatrixBlock input,
                                        MatrixBlock dout,
                                        MatrixBlock outputBlock,
                                        int numThreads)
                                 throws DMLRuntimeException
        Computes the backpropagation errors for the previous layer of a ReLU operation.
        Parameters:
        input - input matrix
        dout - errors from next layer
        outputBlock - output matrix
        numThreads - number of threads
        Throws:
        DMLRuntimeException - if an error occurs while executing the operation
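        The ReLU backward pass passes each upstream error through where the forward input was positive and zeroes it elsewhere. The hypothetical sketch below works on flat double[] arrays rather than MatrixBlock and shows one way a numThreads argument could be honored: the array is split into contiguous chunks processed by a fixed-size thread pool. The chunking strategy and names are assumptions, not SystemML's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Hypothetical sketch: element-wise ReLU backward, split across numThreads workers. */
final class ReluBackwardSketch {
    static double[] reluBackward(double[] input, double[] dout, int numThreads) {
        int len = input.length;
        double[] out = new double[len];
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        try {
            List<Future<?>> tasks = new ArrayList<>();
            int chunk = (len + numThreads - 1) / numThreads; // ceiling division
            for (int t = 0; t < numThreads; t++) {
                int lo = t * chunk, hi = Math.min(len, lo + chunk);
                tasks.add(pool.submit(() -> {
                    // Each task writes a disjoint range, so no locking is needed.
                    for (int i = lo; i < hi; i++)
                        out[i] = input[i] > 0 ? dout[i] : 0.0;
                }));
            }
            for (Future<?> f : tasks) f.get(); // wait for all chunks, surface failures
        } catch (InterruptedException | ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            pool.shutdown();
        }
        return out;
    }
}
```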
      • biasAdd

        public static void biasAdd(MatrixBlock input,
                                   MatrixBlock bias,
                                   MatrixBlock outputBlock,
                                   int numThreads)
                            throws DMLRuntimeException
        Performs the operation corresponding to the following DML script:
            ones = matrix(1, rows=1, cols=Hout*Wout)
            output = input + matrix(bias %*% ones, rows=1, cols=F*Hout*Wout)
        This operation commonly follows conv2d, which is why the bias_add(input, bias) built-in function was introduced.
        Parameters:
        input - input matrix
        bias - bias matrix
        outputBlock - output matrix
        numThreads - number of threads
        Throws:
        DMLRuntimeException - if an error occurs while executing the operation
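        The DML script above adds bias[f] to all Hout*Wout entries of channel f in every row of input. A hypothetical sketch of that broadcast on plain row-major double[] arrays, assuming input is N x (F*HW) with HW = Hout*Wout and bias holds F values (names are illustrative, not SystemML's code):

```java
/** Hypothetical sketch: per-channel bias addition, mirroring the DML script. */
final class BiasAddSketch {
    static double[] biasAdd(double[] input, double[] bias, int N, int F, int HW) {
        double[] out = new double[N * F * HW];
        for (int n = 0; n < N; n++)          // each image (row)
            for (int f = 0; f < F; f++)      // each channel
                for (int j = 0; j < HW; j++) { // each spatial position in the channel
                    int idx = n * F * HW + f * HW + j;
                    out[idx] = input[idx] + bias[f];
                }
        return out;
    }
}
```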
      • biasMultiply

        public static void biasMultiply(MatrixBlock input,
                                        MatrixBlock bias,
                                        MatrixBlock outputBlock,
                                        int numThreads)
                                 throws DMLRuntimeException
        Performs the operation corresponding to the following DML script:
            ones = matrix(1, rows=1, cols=Hout*Wout)
            output = input * matrix(bias %*% ones, rows=1, cols=F*Hout*Wout)
        This operation commonly follows conv2d, which is why the bias_multiply(input, bias) built-in function was introduced.
        Parameters:
        input - input matrix
        bias - bias matrix
        outputBlock - output matrix
        numThreads - number of threads
        Throws:
        DMLRuntimeException - if an error occurs while executing the operation
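        bias_multiply follows the same per-channel broadcast pattern as bias_add, only with multiplication. The hypothetical sketch below factors that pattern into a helper taking a java.util.function.DoubleBinaryOperator; it assumes row-major double[] arrays with input of shape N x (F*HW), HW = Hout*Wout, and F bias values. The helper and class names are illustrative.

```java
import java.util.function.DoubleBinaryOperator;

/** Hypothetical sketch: per-channel bias multiplication via a generic broadcast helper. */
final class BiasMultiplySketch {
    // Applies op(input value, bias[f]) to every element of channel f in every row.
    static double[] biasBroadcast(double[] input, double[] bias,
            int N, int F, int HW, DoubleBinaryOperator op) {
        double[] out = new double[N * F * HW];
        for (int n = 0; n < N; n++)
            for (int f = 0; f < F; f++)
                for (int j = 0; j < HW; j++) {
                    int idx = n * F * HW + f * HW + j;
                    out[idx] = op.applyAsDouble(input[idx], bias[f]);
                }
        return out;
    }

    static double[] biasMultiply(double[] input, double[] bias, int N, int F, int HW) {
        return biasBroadcast(input, bias, N, F, HW, (a, b) -> a * b);
    }
}
```

        Passing (a, b) -> a + b to the same helper would give the bias_add variant, which is one reason to share the loop nest.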

Copyright © 2017 The Apache Software Foundation. All rights reserved.