public class LibMatrixDNNHelper extends Object
Modifier and Type | Class and Description |
---|---|
static class | LibMatrixDNNHelper.ReluBackward: Performs the operation (X > 0) * dout, i.e., passes the gradient dout through only where X is positive. |
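The ReluBackward operation multiplies the upstream gradient dout elementwise by the indicator X > 0. Below is a minimal dense sketch of that formula. It assumes X and dout are equally sized `double[]` arrays in the same layout. The class `ReluBackwardSketch` and its method `reluBackward` are hypothetical illustrations, not part of SystemML's API, and the sketch ignores sparse inputs and multithreading.

```java
import java.util.Arrays;

public class ReluBackwardSketch {
    // Hypothetical helper (not SystemML API): computes (X > 0) * dout elementwise
    // on dense arrays of equal length.
    static double[] reluBackward(double[] x, double[] dout) {
        if (x.length != dout.length)
            throw new IllegalArgumentException("x and dout must have the same length");
        double[] dx = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            // The gradient flows through only where the forward input was strictly positive;
            // X == 0 is treated as "not positive", matching the (X > 0) indicator.
            dx[i] = x[i] > 0 ? dout[i] : 0.0;
        }
        return dx;
    }

    public static void main(String[] args) {
        double[] x = {-1.0, 0.0, 2.0, 3.5};
        double[] dout = {10.0, 20.0, 30.0, 40.0};
        System.out.println(Arrays.toString(reluBackward(x, dout)));
        // prints [0.0, 0.0, 30.0, 40.0]
    }
}
```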
Constructor and Description |
---|
LibMatrixDNNHelper() |
Modifier and Type | Method and Description |
---|---|
static ArrayList<Callable<Long>> | getConv2dBackwardDataWorkers(ConvolutionParameters params): Factory method that returns a list of callable tasks for computing the conv2d backward pass with respect to the input data. |
static ArrayList<Callable<Long>> | getConv2dBackwardFilterWorkers(ConvolutionParameters params): Factory method that returns a list of callable tasks for computing the conv2d backward pass with respect to the filter. |
static ArrayList<Callable<Long>> | getConv2dWorkers(ConvolutionParameters params): Factory method that returns a list of callable tasks for performing the conv2d (forward) operation. |
static ArrayList<Callable<Long>> | getMaxPoolingBackwardWorkers(ConvolutionParameters params, boolean performReluBackward): Factory method that returns a list of callable tasks for performing the max pooling backward operation, optionally combined with ReLU backward (controlled by performReluBackward). |
static ArrayList<Callable<Long>> | getMaxPoolingWorkers(ConvolutionParameters params): Factory method that returns a list of callable tasks for performing the max pooling operation. |
static ArrayList<Callable<Long>> | getReluBackwardWorkers(ConvolutionParameters params): Factory method that returns a list of callable tasks for performing the ReLU backward operation. |
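Each factory method returns its work as an ArrayList<Callable<Long>>, leaving it to the caller to decide how the tasks are run. The sketch below shows one way to do that: submit the tasks to a fixed thread pool and sum their Long results. It assumes the tasks are independent and that their results can be added together. Reading the Long as a per-task count of non-zero output cells is an assumption; this page does not document it. `makeWorkers` and `runAndSum` are hypothetical names. `makeWorkers` stands in for a factory such as getReluBackwardWorkers and is not how SystemML partitions work.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class WorkerExecutionSketch {
    // Hypothetical stand-in for a factory method: splits the data into at most k
    // contiguous ranges and creates one Callable per range. Each task returns the
    // number of positive entries in its range (modeling a per-task non-zero count).
    static ArrayList<Callable<Long>> makeWorkers(double[] data, int k) {
        ArrayList<Callable<Long>> tasks = new ArrayList<>();
        int n = data.length;
        int chunk = (n + k - 1) / k; // ceiling division so every element is covered
        for (int start = 0; start < n; start += chunk) {
            final int lo = start, hi = Math.min(n, start + chunk);
            tasks.add(() -> {
                long count = 0;
                for (int i = lo; i < hi; i++)
                    if (data[i] > 0) count++;
                return count;
            });
        }
        return tasks;
    }

    // Runs all tasks on a fixed-size thread pool and sums their Long results.
    static long runAndSum(List<Callable<Long>> tasks, int threads)
            throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            long total = 0;
            for (Future<Long> f : pool.invokeAll(tasks))
                total += f.get(); // rethrows a task's failure as ExecutionException
            return total;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        double[] data = {1, -2, 3, 0, 5, 6, -7};
        long nnz = runAndSum(makeWorkers(data, 3), 2);
        System.out.println(nnz); // prints 4
    }
}
```

ExecutorService.invokeAll blocks until every task has completed. Calling Future.get on each result then surfaces any task failure to the caller, so a failing slice is not silently dropped from the sum.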
public static ArrayList<Callable<Long>> getMaxPoolingWorkers(ConvolutionParameters params) throws DMLRuntimeException
Parameters:
params - convolution parameters
Throws:
DMLRuntimeException - if an error occurs

public static ArrayList<Callable<Long>> getMaxPoolingBackwardWorkers(ConvolutionParameters params, boolean performReluBackward) throws DMLRuntimeException
Parameters:
params - convolution parameters
performReluBackward - whether to perform ReLU backward
Throws:
DMLRuntimeException - if an error occurs

public static ArrayList<Callable<Long>> getReluBackwardWorkers(ConvolutionParameters params) throws DMLRuntimeException
Parameters:
params - convolution parameters
Throws:
DMLRuntimeException - if an error occurs

public static ArrayList<Callable<Long>> getConv2dWorkers(ConvolutionParameters params) throws DMLRuntimeException
Parameters:
params - convolution parameters
Throws:
DMLRuntimeException - if an error occurs

public static ArrayList<Callable<Long>> getConv2dBackwardFilterWorkers(ConvolutionParameters params) throws DMLRuntimeException
Parameters:
params - convolution parameters
Throws:
DMLRuntimeException - if an error occurs

public static ArrayList<Callable<Long>> getConv2dBackwardDataWorkers(ConvolutionParameters params) throws DMLRuntimeException
Parameters:
params - convolution parameters
Throws:
DMLRuntimeException - if an error occurs

Copyright © 2017 The Apache Software Foundation. All rights reserved.