diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index de421b2970b9..a0e80bc9fbc1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -38,6 +38,8 @@ import org.nd4j.autodiff.samediff.config.OutputConfig; import org.nd4j.autodiff.samediff.internal.*; import org.nd4j.autodiff.samediff.ops.*; +import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizationConfig; import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper; import org.nd4j.base.Preconditions; import org.nd4j.evaluation.IEvaluation; @@ -109,6 +111,7 @@ @Slf4j public class SameDiff extends SDBaseOps { protected static final String GRAD_FN_KEY = "grad"; + protected static final String OPTIMIZED_FN_KEY = "optimized"; //Fields for graph structure and execution @Getter @@ -118,7 +121,9 @@ public class SameDiff extends SDBaseOps { @Getter private final Map sessions = new ConcurrentHashMap<>(); //Key: thread ID + @Getter @Setter //TODO shouldn't be in public API private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true); + @Getter @Setter //TODO shouldn't be in public API private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true); private final Map> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them @@ -238,6 +243,11 @@ public SDBitwise bitwise(){ return bitwise; } + @Setter @Getter + private boolean allowOptimization = true; + private String[] optimizedWRT = null; + + private Map sameDiffFunctionInstances; private Table fieldVariableResolutionMapping; @@ -2554,6 +2564,23 @@ protected Map batchOutputHelper(Map placehol activeListeners.add(l); } + 
if(allowOptimization){ + if(!sameDiffFunctionInstances.containsKey(OPTIMIZED_FN_KEY) || optimizedWRT == null || !Arrays.equals(optimizedWRT, outputs)){ + //Need to create optimized version + + SameDiff sd = optimize(Arrays.asList(outputs)); + sameDiffFunctionInstances.put(OPTIMIZED_FN_KEY, sd); + + + //TODO clean up old version optimized SameDiff if necessary + } + SameDiff optimized = sameDiffFunctionInstances.get(OPTIMIZED_FN_KEY); + if(optimized.isAllowOptimization()) + optimized.setAllowOptimization(false); //Prevent recursive optimizations + + return optimized.batchOutputHelper(placeholders, activeListeners, operation, outputs); + } + for (Listener l : activeListeners) { l.operationStart(this, operation); } @@ -5863,4 +5890,10 @@ public String generateDistinctCustomVariableName(String base){ return base + "_" + inc; } + + protected SameDiff optimize(List withRespectToOutputs){ + SameDiff sd = GraphOptimizer.optimize(this, withRespectToOutputs); + sd.setAllowOptimization(false); //Prevent recursive optimization attempts when output is called + return sd; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java new file mode 100644 index 000000000000..0b2470676ffc --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/array/OptimizedGraphArrayHolder.java @@ -0,0 +1,78 @@ +package org.nd4j.autodiff.samediff.array; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +import java.util.*; + +public class OptimizedGraphArrayHolder implements ArrayHolder { + + private final ArrayHolder underlyingHolder; + private final Map> functions; + + public OptimizedGraphArrayHolder(ArrayHolder 
underlyingHolder){ + this.underlyingHolder = underlyingHolder; + this.functions = new HashMap<>(); + } + + public void setFunction(String name, Supplier fn){ + if(underlyingHolder.hasArray(name)) + underlyingHolder.removeArray(name); + functions.put(name, fn); + } + + @Override + public boolean hasArray(String name) { + return functions.containsKey(name) || underlyingHolder.hasArray(name); + } + + @Override + public INDArray getArray(String name) { + if(functions.containsKey(name)) + return functions.get(name).get(); + return underlyingHolder.getArray(name); + } + + @Override + public void setArray(String name, INDArray array) { + Preconditions.checkState(!functions.containsKey(name), "Cannot set array when existing array is only accessible via a function"); + underlyingHolder.setArray(name, array); + } + + @Override + public INDArray removeArray(String name) { + Supplier s = functions.remove(name); + if(s != null) + return s.get(); + return underlyingHolder.removeArray(name); + } + + @Override + public int size() { + return underlyingHolder.size() + functions.size(); + } + + @Override + public void initFrom(ArrayHolder arrayHolder) { + underlyingHolder.initFrom(arrayHolder); + } + + @Override + public Collection arrayNames() { + Set set = new HashSet<>(); + set.addAll(underlyingHolder.arrayNames()); + set.addAll(functions.keySet()); + return set; + } + + @Override + public void rename(String from, String to) { + if(functions.containsKey(from)) { + functions.put(to, functions.remove(from)); + } else { + underlyingHolder.rename(from, to); + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java new file mode 100644 index 000000000000..96d1ad574f0f --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/GraphOptimizer.java @@ -0,0 +1,146 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger; +import org.nd4j.autodiff.samediff.optimize.optimizations.*; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +/** + * + * @author Alex Black + */ +@Slf4j +public class GraphOptimizer { + + public static List defaultOptimizations(){ + return Arrays.asList( + new UnusedFunctionOptimizations(), + new ConstantFunctionOptimizations(), + new IdentityFunctionOptimizations(), + new ShapeFunctionOptimizations(), + new UnusedFunctionOptimizations(), + new CuDNNFunctionOptimizations() + ); + } + + public static SameDiff optimize(SameDiff graph, String... 
requiredOutputs){ + return optimize(graph, Arrays.asList(requiredOutputs)); + } + + public static SameDiff optimize(SameDiff graph, List requiredOutputs){ + return optimize(graph, requiredOutputs, defaultOptimizations()); + } + + public static SameDiff optimize(SameDiff graph, List requiredOutputs, List optimizations) { + return optimize(graph, requiredOutputs, optimizations, null); + } + + public static SameDiff optimize(SameDiff graph, List requiredOutputs, List optimizations, OptimizationDebugger debugger){ + //TODO Use required outputs - strip unnecessary graph components + + SameDiff sd = graph.dup(); + + ArrayHolder cArr = sd.getConstantArrays(); + ArrayHolder vArr = sd.getVariablesArrays(); + + OptimizationHelper h = new OptimizationHelper(graph, new OptimizationConfig()); //TODO defaults for config + + for( int i=0; i<3; i++ ) { //Run multiple times - one run isn't enough, as some more optimizations may need to be applied to the output of earlier optimizations + for (OptimizerSet s : optimizations) { + List l = s.getOptimizers(); + for(Optimizer o : l ){ + Collection startingOps = new ArrayList<>(sd.getOps().values()); //Create list to avoid concurrent modification exception + for(SameDiffOp op : startingOps) { + //Because ops might disappear from previous optimization steps, we need to check if the previous op + // still exists when iterating... 
+ if(!sd.getOps().containsKey(op.getName())) + continue; + + if(debugger != null) + debugger.beforeOptimizationCheck(sd, op, o); + + boolean applied = o.checkAndApply(sd, h, op, cArr, vArr); + if(applied) { + log.info("Operation was applied: {}", o); + } + + if(debugger != null) + debugger.afterOptimizationsCheck(sd, op, o, applied); + } + } + } + } + + int constBefore = 0; + int constAfter = 0; + int varBefore = 0; + int varAfter = 0; + int arrBefore = 0; + int arrAfter = 0; + + for(SDVariable v : graph.variables()){ + switch(v.getVariableType()){ + case VARIABLE: + varBefore++; + break; + case CONSTANT: + constBefore++; + break; + case ARRAY: + arrBefore++; + break; + case PLACEHOLDER: + break; + } + } + + for(SDVariable v : sd.variables()){ + switch(v.getVariableType()){ + case VARIABLE: + varAfter++; + break; + case CONSTANT: + constAfter++; + break; + case ARRAY: + arrAfter++; + break; + case PLACEHOLDER: + break; + } + } + + + log.info("Total variables: {} before, {} after", graph.getVariables().size(), sd.getVariables().size()); + log.info("Constant variables: {} before, {} after", constBefore, constAfter); + log.info("Array type variables: {} before, {} after", arrBefore, arrAfter); + log.info("Variable type variables: {} before, {} after", varBefore, varAfter); + log.info("Ops: {} before, {} after", graph.getOps().size(), sd.getOps().size()); + + return sd; + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java new file mode 100644 index 000000000000..ebc1b036eca8 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationConfig.java @@ -0,0 +1,22 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize; + +import java.util.Properties; + +public class OptimizationConfig extends Properties { + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java new file mode 100644 index 000000000000..a8470b1926a7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizationHelper.java @@ -0,0 +1,69 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize; + +import lombok.Getter; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.array.OptimizedGraphArrayHolder; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +import java.util.Properties; + +public class OptimizationHelper { + + private final SameDiff originalGraph; + @Getter + private final Properties properties; + private boolean setConstantHolder = false; + private boolean setVariableHolder = false; + + public OptimizationHelper(SameDiff originalGraph, Properties properties){ + this.originalGraph = originalGraph; + this.properties = properties; + } + + public OptimizationHelper arrayRecoveryFunction(String arrayName, Supplier fn){ + SDVariable v = originalGraph.getVariable(arrayName); + Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT, + "Can only set an array recovery function for a variable or a constant"); + + if(v.getVariableType() == VariableType.VARIABLE){ + ArrayHolder h = originalGraph.getVariablesArrays(); + if(!setVariableHolder){ + originalGraph.setVariablesArrays(new OptimizedGraphArrayHolder(h)); + h = originalGraph.getVariablesArrays(); + setVariableHolder = true; + } + ((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn); + } else { + ArrayHolder h = originalGraph.getConstantArrays(); + if(!setConstantHolder){ + originalGraph.setConstantArrays(new OptimizedGraphArrayHolder(h)); + h = originalGraph.getConstantArrays(); + setConstantHolder = true; + } + ((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn); + } + + return this; + } + +} diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java new file mode 100644 index 000000000000..2411562603d4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/Optimizer.java @@ -0,0 +1,39 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; + +import java.util.Properties; + +/** + * @author Alex Black + */ +public interface Optimizer { + + /** + * @param sd Current SameDiff instance to optimize + * @param helper Helper class for optimization + * @param op Operation to check for optimization + * @param constantArrays Array holder for constant arrays + * @param variablesArrays Array holder for variable arrays + * @return True if the optimization was applied + */ + boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java new file mode 100644 index 000000000000..6c3bca83df05 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/OptimizerSet.java @@ -0,0 +1,28 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize; + +import java.util.List; + +/** + * + * @author Alex Black + */ +public interface OptimizerSet { + + List getOptimizers(); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java new file mode 100644 index 000000000000..db4663b1abc3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/debug/OptimizationDebugger.java @@ -0,0 +1,33 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.debug; + +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; + +/** + * Used as a listener for debugging graph optimizations: invoked before and after each optimizer check. + * + * @author Alex Black + */ +public interface OptimizationDebugger { + + void beforeOptimizationCheck(SameDiff sd, SameDiffOp op, Optimizer o); + + void afterOptimizationsCheck(SameDiff sd, SameDiffOp op, Optimizer o, boolean wasApplied); + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java new file mode 100644 index 000000000000..7a60745ed3f4 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/BaseOptimizerSet.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.List; + +/** + * + * @author Alex Black + */ +@Slf4j +public abstract class BaseOptimizerSet implements OptimizerSet { + + + @Override + public List getOptimizers() { + Method[] methods = this.getClass().getDeclaredMethods(); + List out = new ArrayList<>(methods.length); + for(Method m : methods){ + int modifiers = m.getModifiers(); + Class retType = m.getReturnType(); + if(retType != null && Modifier.isPublic(modifiers) && Optimizer.class.isAssignableFrom(retType) ){ + try { + Optimizer o = (Optimizer) m.invoke(null); + out.add(o); + } catch (IllegalAccessException | InvocationTargetException e) { + log.warn("Could not create optimizer from method: {}", m, e); + } + } + } + + Class[] declaredClasses = this.getClass().getDeclaredClasses(); + for(Class c : declaredClasses){ + int modifiers = c.getModifiers(); + if(Modifier.isPublic(modifiers) && !Modifier.isAbstract(modifiers) && Optimizer.class.isAssignableFrom(c)){ + try{ + out.add((Optimizer) c.newInstance()); + } catch (IllegalAccessException | InstantiationException e) { + log.warn("Could not create optimizer from inner class: {}", c, e); + } + } + } + + return out; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java new file mode 100644 index 
000000000000..5f3d0a2af09e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ConstantFunctionOptimizations.java @@ -0,0 +1,111 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.List; + +/** + * This set of optimizations looks for functions that are applied to constants, and "pre executes" them, so they don't have + * to be calculated (returning the same value) on each run. 
+ * + * @author Alex Black + */ +public class ConstantFunctionOptimizations extends BaseOptimizerSet { + + public static final String CONSTANT_FN_FOLDING_MAX_SIZE = "optimizer.constants.function.max.output.size"; + public static final long CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT = 4 * 1024 * 1024; //4MB + + public static class FoldConstantFunctions implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + //TODO This function needs to check for non-deterministic ops - i.e., random ops - and not apply the optimization to these + + List in = op.getInputsToOp(); + if (in == null || in.isEmpty()) + return false; + for (String s : in) { + if (!sd.getVariable(s).isConstant()) + return false; + } + + long maxSizeToApply = Long.parseLong(helper.getProperties().getProperty(CONSTANT_FN_FOLDING_MAX_SIZE, String.valueOf(CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT))); + //Apply the optimization: + DifferentialFunction df = op.getOp(); + df.clearArrays(); + for (int i = 0; i < in.size(); i++) { + String s = in.get(i); + INDArray arr = sd.getVariable(s).getArr(); + if (df instanceof CustomOp) { + ((CustomOp) df).addInputArgument(arr); + } else { + if (i == 0) + ((Op) df).setX(arr); + else + ((Op) df).setY(arr); + } + } + + INDArray[] outputs; + if (df instanceof CustomOp) { + CustomOp o = (CustomOp) df; + Nd4j.exec(o); + outputs = new INDArray[o.numOutputArguments()]; + for (int j = 0; j < outputs.length; j++) { + outputs[j] = o.getOutputArgument(j); + } + } else { + Op o = (Op) df; + Nd4j.exec(o); + outputs = new INDArray[]{o.z()}; + } + long sizeCount = 0; + for (INDArray i : outputs) { + if (!i.dataType().isNumerical()) + continue; + sizeCount += i.length() * i.dataType().width(); + } + + if (sizeCount > maxSizeToApply) + return false; + + //Convert outputs to constants + List outputNames = op.getOutputsOfOp(); + for(int i=0; i inputs = op.getInputsToOp(); + String 
wArgName = inputs.get(1); + + //Step 1 - replace activations + if(!activationsCorrect) { + String inArgName = inputs.get(0); + SDVariable in = sd.getVariable(inArgName); + //Replace [in -> Conv2d(NCHW) -> out] with [in -> permute -> Conv2d(NHWC) -> permute -> out] + String newName = in.name() + "_cudnn_nchw_to_nhwc"; + OptimizationUtils.replaceOpInputsWith(sd, in.name(), newName); + SDVariable nhwc = in.permute(0, 2, 3, 1).rename(newName); //NCHW to NHWC + + SDVariable outNhwc = sd.getVariable(op.getOutputsOfOp().get(0)); + String newName2 = outNhwc.name() + "_cudnn_nhwc_to_nchw"; + SDVariable outNchw = outNhwc.permute(0, 3, 1, 2).rename(newName2); //NHWC to NCHW + + OptimizationUtils.replaceOpInputsWith(sd, outNhwc.name(), outNchw.name()); + + c2d.getConfig().isNHWC(true); + } + + //Step 2 - replace YXIO weights (default) with OYXI weights + //We'll just add a permute here, and let other optimizer steps fix the (variable -> permute -> op ==> permutedVariable -> op) part + if(!weightsCorrect) { + SDVariable w = sd.getVariable(wArgName); + String newWname = w.name() + "_cudnn_yxio_to_oyxi"; + OptimizationUtils.replaceOpInputsWith(sd, w.name(), newWname); + SDVariable wPermuted = w.permute(3, 0, 1, 2).rename(newWname); + + + //TODO once config supports weight layout, set it here + } + + + return true; + } + } + + /* + TODO: Also do pooling2d, batchnorm, etc + */ + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java new file mode 100644 index 000000000000..be1c1dc09d5c --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/IdentityFunctionOptimizations.java @@ -0,0 +1,58 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; + +import java.util.Properties; + +public class IdentityFunctionOptimizations extends BaseOptimizerSet { + + /** + * Remove permute(0,1,2,...,rank-1) as this is a no-op + */ + public static class RemoveIdentityPermute implements Optimizer { + + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + + /** + * Remove identity(x) + */ + public static class RemoveIdentityOps implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + if(op.getOp() instanceof Identity){ + String inName = op.getInputsToOp().get(0); + String outputName = op.getOutputsOfOp().get(0); + 
OptimizationUtils.removeOp(sd, op.getName()); + OptimizationUtils.replaceOpInputsWith(sd, outputName, inName); + OptimizationUtils.removeVariable(sd, outputName); + return true; + } + + return false; + } + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java new file mode 100644 index 000000000000..4c6503c6da67 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/OptimizationUtils.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +public class OptimizationUtils { + + private OptimizationUtils(){ } + + public static void replaceOpInputsWith(SameDiff sd, @NonNull String replaceInput, @NonNull String newInput){ + if(replaceInput.equals(newInput)) + return; + + //Update op input structure: Replace all instances replaceInput->X with newInput->X + Collection ops = sd.getOps().values(); + for(SameDiffOp o : ops){ + List l = o.getInputsToOp(); + while(l != null && l.contains(replaceInput)){ + int idx = l.indexOf(replaceInput); + l.set(idx, newInput); + } + } + + //Update variable structure + Variable v = sd.getVariables().get(replaceInput); + Variable v2 = sd.getVariables().get(newInput); + //NOTE: this only works if we carefully control the order in which replaceOpInputsWith is called! 
+ v2.setInputsForOp(v.getInputsForOp()); + v.setInputsForOp(new ArrayList()); + } + + public static void removeOp(@NonNull SameDiff sd, @NonNull String opToRemove){ + SameDiffOp op = sd.getOps().remove(opToRemove); + for(String s : op.getInputsToOp()){ + Variable v = sd.getVariables().get(s); + v.getInputsForOp().remove(op.getName()); + } + } + + public static void removeVariable(@NonNull SameDiff sd, @NonNull String varToRemove){ + sd.getVariables().remove(varToRemove); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java new file mode 100644 index 000000000000..f45b92bce372 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/ShapeFunctionOptimizations.java @@ -0,0 +1,91 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ops.impl.shape.Permute; + +import java.util.ArrayList; +import java.util.List; + +public class ShapeFunctionOptimizations extends BaseOptimizerSet { + + /** + * Fuse [permute1 -> permute2 -> ... -> permuteN] into a single permute op, + * as long as the intermediate permute outputs aren't needed for another op + */ + public static class FuseChainedPermutes implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + if(!(op.getOp() instanceof Permute)) + return false; + + List inputs = op.getInputsToOp(); + String input = inputs.get(0); + + List toFuse = new ArrayList<>(); + toFuse.add(op.getName()); + String currInput = input; + while(currInput != null){ + Variable v = sd.getVariables().get(currInput); + //In order to fuse permute operations, we require: + // (a) the intermediate variable is ONLY needed by the next permute + // (b) the permute dimensions are constant, + + if(v.getInputsForOp().size() > 1) + break; + } + + if(toFuse.size() > 1){ + //Fuse the permute ops + +// return true; + return false; + } + + + return false; + } + } + + /** + * Fuse [reshape1 -> reshape2 -> ... 
-> reshapeN] into a single reshape op, + * as long as the intermediate reshape ops aren't needed for another op + */ + public static class FuseChainedReshapes implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + + /** + * Fuse [concat(concat(concat(x,y,dim=D), z, dim=D), a, dim=D)] into a single concat op, concat(x,y,z,a, dim=D) + * As long as the intermediate outputs aren't needed elsewhere + */ + public static class FuseChainedConcatOps implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + return false; + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java new file mode 100644 index 000000000000..96d3d34de95d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/optimize/optimizations/UnusedFunctionOptimizations.java @@ -0,0 +1,68 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.autodiff.samediff.optimize.optimizations; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.ArrayHolder; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.internal.Variable; +import org.nd4j.autodiff.samediff.optimize.OptimizationHelper; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.function.Supplier; + +import java.util.ArrayList; +import java.util.List; + +@Slf4j +public class UnusedFunctionOptimizations extends BaseOptimizerSet { + + public static class RemoveUnusedConstants implements Optimizer { + @Override + public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) { + //TODO check this once _per graph_ not per op + List variables = new ArrayList<>(sd.getVariables().values()); + boolean anyRemoved = false; + for(Variable v : variables){ + if(v.getVariable().getVariableType() == VariableType.CONSTANT){ + List inputFor = v.getInputsForOp(); + if(inputFor == null || inputFor.isEmpty()){ + //This constant isn't used... + + //TODO let's put these on disk instead of keeping them in memory... 
+ final INDArray arr = v.getVariable().getArr(); + helper.arrayRecoveryFunction(v.getName(), new Supplier() { + @Override + public INDArray get() { + return arr; + } + }); + + sd.getVariables().remove(v.getName()); + log.info("Removed unused constant: {}", v.getName()); + anyRemoved = true; + } + } + } + return anyRemoved; + } + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java new file mode 100644 index 000000000000..ecdd77afc8f3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestOptimization.java @@ -0,0 +1,134 @@ +package org.nd4j.autodiff.optimization; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.nd4j.autodiff.optimization.util.OptTestConfig; +import org.nd4j.autodiff.optimization.util.OptimizationTestUtil; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations; +import org.nd4j.autodiff.samediff.optimize.optimizations.IdentityFunctionOptimizations; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.*; + +public class TestOptimization extends BaseNd4jTest { + + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + + public TestOptimization(Nd4jBackend backend) { + super(backend); + } + + @Override + public char ordering() { + return 'c'; + } + + @Override + public long getTimeoutMilliseconds() { + return 1_000_000_000L; + } + + @Test + public void 
testConstantOpFolding(){ + //We expect 2 things in this test: + //(a) the output of add(constant, constant) is pre-calculated and itself becomes a constant + //(b) the + + + SameDiff sd = SameDiff.create(); + SDVariable c = sd.constant("c", Nd4j.scalar(1.0)); + SDVariable c2 = c.add("add", 1); + SDVariable v = sd.var("variable", Nd4j.scalar(1.0)); + SDVariable out = v.sub("out", c2); + + SameDiff copy = sd.dup(); + + SameDiff optimized = GraphOptimizer.optimize(sd, "out"); + assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed + assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); + assertEquals(1, optimized.getOps().size()); + assertEquals("subtract", optimized.getOps().values().iterator().next().getName()); + + assertFalse(optimized.hasVariable("c")); + + assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); + + //Check the + + //Check that the original can be saved and loaded, and still gives the same results + + } + + @Test + public void testConstantOpFolding2(){ + //We expect 2 things in this test: + //(a) the output of add(constant, constant) is pre-calculated and itself becomes a constant + //(b) the + + + SameDiff sd = SameDiff.create(); + SDVariable c = sd.constant("c", Nd4j.scalar(1.0)); + SDVariable c2 = c.add("add", 1); + SDVariable v = sd.var("variable", Nd4j.scalar(1.0)); + SDVariable out = v.sub("out", c2); + + OptTestConfig conf = OptTestConfig.builder() + .original(sd) + .outputs(Collections.singletonList("out")) + .mustApply(sd.getVariables().get("add").getOutputOfOp(), ConstantFunctionOptimizations.FoldConstantFunctions.class) + .build(); + + SameDiff optimized = OptimizationTestUtil.testOptimization(conf); + assertEquals(3, optimized.getVariables().size()); //"add", "variable", "out" -> "c" should be removed + assertEquals(VariableType.CONSTANT, optimized.getVariable("add").getVariableType()); + 
assertEquals(1, optimized.getOps().size()); + assertEquals("subtract", optimized.getOps().values().iterator().next().getName()); + + assertFalse(optimized.hasVariable("c")); + + assertEquals(sd.outputSingle(Collections.emptyMap(), "out"), optimized.outputSingle(Collections.emptyMap(), "out")); + + } + + @Test + public void testIdentityRemoval(){ + + //Ensure that optimizer is actually used when calling output methods: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + SDVariable i1 = sd.identity(in); + SDVariable i2 = sd.identity(w); + SDVariable i3 = sd.identity(b); + SDVariable out = sd.nn.softmax("out", sd.identity(i1.mmul(i2).add(i3))); + + OptTestConfig conf = OptTestConfig.builder() + .original(sd) + .outputs(Collections.singletonList("out")) + .placeholder("in", Nd4j.rand(DataType.FLOAT, 5, 4)) + .mustApply(sd.getVariables().get(i1.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .mustApply(sd.getVariables().get(i2.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .mustApply(sd.getVariables().get(i3.name()).getOutputOfOp(), IdentityFunctionOptimizations.RemoveIdentityOps.class) + .build(); + + SameDiff optimized = OptimizationTestUtil.testOptimization(conf); + assertEquals(3, optimized.getOps().size()); + assertFalse(optimized.hasVariable(i1.name())); + assertFalse(optimized.hasVariable(i2.name())); + assertFalse(optimized.hasVariable(i3.name())); + assertTrue(optimized.hasVariable("out")); + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java new file mode 100644 index 000000000000..ec1914af1171 --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/TestSeamlessOptimization.java @@ -0,0 +1,137 @@ +package org.nd4j.autodiff.optimization; + +import lombok.Data; +import org.junit.Test; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.listeners.At; +import org.nd4j.autodiff.listeners.BaseListener; +import org.nd4j.autodiff.listeners.Operation; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.BaseNd4jTest; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.Mmul; +import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.factory.Nd4jBackend; + +import java.util.*; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class TestSeamlessOptimization extends BaseNd4jTest { + + public TestSeamlessOptimization(Nd4jBackend backend) { + super(backend); + } + + + @Test + public void testOutput(){ + + //Ensure that optimizer is actually used when calling output methods: + SameDiff sd = SameDiff.create(); + SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4); + SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3)); + SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3)); + + SDVariable i1 = sd.identity(in); + SDVariable i2 = sd.identity(w); + SDVariable i3 = sd.identity(b); + + SDVariable out = sd.nn.softmax("out", sd.identity(i1.mmul(i2).add(i3))); + + RecordOpsListener l = new RecordOpsListener(); + sd.setListeners(new AssertNoOpsOfTypeListener(Identity.class), l); + 
+ Map ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 10, 4)); + + for( int i=0; i<3; i++ ) { + l.ops.clear(); + + switch (i){ + case 0: + sd.outputSingle(ph, "out"); + break; + case 1: + sd.output(ph, "out"); + break; + case 2: + sd.batchOutput().output("out") + .input("in", ph.get("in")) + .outputSingle(); + break; + } + + + List> expClasses = Arrays.asList(Mmul.class, AddOp.class, SoftMax.class); + assertEquals(3, l.ops.size()); + for (int j = 0; j < 3; j++) { + assertEquals(expClasses.get(j), l.ops.get(j).getOp().getClass()); + } + + } + } + + @Test + public void testDifferentOutputs(){ + //Test when the user requests different outputs instead + } + + @Test + public void testGraphModification(){ + //User modifies the graph -> should reoptimize? + + fail("Not yet implemented"); + } + + public static class AssertNoOpsOfTypeListener extends BaseListener { + private List> list; + + public AssertNoOpsOfTypeListener(Class... c) { + Preconditions.checkState(c != null && c.length > 0, "No classes provided"); + this.list = Arrays.asList(c); + } + + @Override + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + if(list.contains(op.getOp().getClass())){ + throw new IllegalStateException("Encountered unexpected class: " + op.getOp().getClass().getName()); + } + } + } + + @Data + public static class RecordOpsListener extends BaseListener { + + private List ops = new ArrayList<>(); + + @Override + public boolean isActive(Operation operation) { + return true; + } + + @Override + public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, INDArray[] outputs) { + ops.add(op); + } + } + + + @Override + public char ordering() { + return 'c'; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java 
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java new file mode 100644 index 000000000000..57a099a72bea --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptTestConfig.java @@ -0,0 +1,91 @@ +package org.nd4j.autodiff.optimization.util; + +import lombok.Builder; +import lombok.Data; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.File; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Data +public class OptTestConfig { + + private SameDiff original; + private Map placeholders; + private List outputs; + private File tempFolder; + private Map> mustApply; + private List optimizerSets; + + public static Builder builder(){ + return new Builder(); + } + + public static class Builder { + + private SameDiff original; + private Map placeholders; + private List outputs; + private File tempFolder; + private Map> mustApply; + private List optimizerSets; + + public Builder original(SameDiff sd){ + original = sd; + return this; + } + + public Builder placeholder(String ph, INDArray arr){ + if(placeholders == null) + placeholders = new HashMap<>(); + placeholders.put(ph, arr); + return this; + } + + public Builder placeholders(Map map){ + placeholders = map; + return this; + } + + public Builder outputs(String... 
outputs){ + this.outputs = Arrays.asList(outputs); + return this; + } + + public Builder outputs(List outputs){ + this.outputs = outputs; + return this; + } + + public Builder mustApply(String opName, Class optimizerClass){ + if(mustApply == null) + mustApply = new HashMap<>(); + mustApply.put(opName, optimizerClass); + return this; + } + + public Builder optimizerSets(List list){ + this.optimizerSets = list; + return this; + } + + public OptTestConfig build(){ + OptTestConfig c = new OptTestConfig(); + c.original = original; + c.placeholders = placeholders; + c.outputs = outputs; + c.tempFolder = tempFolder; + c.mustApply = mustApply; + c.optimizerSets = optimizerSets; + return c; + } + + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java new file mode 100644 index 000000000000..74c134f4e9a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationRecordingDebugger.java @@ -0,0 +1,28 @@ +package org.nd4j.autodiff.optimization.util; + +import lombok.Getter; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.internal.SameDiffOp; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger; + +import java.util.HashMap; +import java.util.Map; + +public class OptimizationRecordingDebugger implements OptimizationDebugger { + + @Getter + private Map applied = new HashMap<>(); + + @Override + public void beforeOptimizationCheck(SameDiff sd, SameDiffOp op, Optimizer o) { + //No op + } + + @Override + public void afterOptimizationsCheck(SameDiff sd, SameDiffOp op, Optimizer o, boolean wasApplied) { + if(wasApplied){ + applied.put(op.getName(), o); + } + } +} diff --git 
a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java new file mode 100644 index 000000000000..0e0cade8a745 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/optimization/util/OptimizationTestUtil.java @@ -0,0 +1,101 @@ +package org.nd4j.autodiff.optimization.util; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.samediff.VariableType; +import org.nd4j.autodiff.samediff.optimize.GraphOptimizer; +import org.nd4j.autodiff.samediff.optimize.Optimizer; +import org.nd4j.autodiff.samediff.optimize.OptimizerSet; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.io.File; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * TODO: + * - Add ability to track which optimization functions exactly were applied! 
+ */ +public class OptimizationTestUtil { + + private OptimizationTestUtil(){ } + + public static SameDiff testOptimization(OptTestConfig config){ + Preconditions.checkNotNull(config.getTempFolder(), "Temp folder should be specified before running test"); + + List optimizerSets = config.getOptimizerSets(); + if(optimizerSets == null) + optimizerSets = GraphOptimizer.defaultOptimizations(); + OptimizationRecordingDebugger debugger = new OptimizationRecordingDebugger(); + + // + Map ph = config.getPlaceholders(); + List outputs = config.getOutputs(); + SameDiff original = config.getOriginal(); + SameDiff copy = original.dup(); + SameDiff optimized = GraphOptimizer.optimize(original, outputs, optimizerSets, debugger); + + //Check that SOMETHING changed in the optimized - number of constants, variables, or ops; or the settings for ops; or the values of some arrays + //TODO + boolean sameNumConst = original.getConstantArrays().size() == optimized.getConstantArrays().size(); + boolean sameNumVars = original.getVariablesArrays().size() == optimized.getVariablesArrays().size(); + boolean sameNumSDVars = original.getVariables().size() == optimized.getVariables().size(); + boolean sameNumOps = original.getOps().size() == optimized.getOps().size(); + + if(sameNumConst && sameNumVars && sameNumSDVars && sameNumOps){ + + + throw new IllegalStateException("Did not detect any changes to the graph structure after optimization (but check is AS YET WIP)"); + } + + //Check that optimizations we expected to be applied were in fact applied: + Map> mustApply = config.getMustApply(); + Map applied = debugger.getApplied(); + for(String s : mustApply.keySet()){ + assertTrue("Expected optimizer of type " + mustApply.get(s).getSimpleName() + " to be applied to op " + s, + applied.containsKey(s)); + } + + + //Second: check that they all produce the same + //TODO this won't work for random ops! 
+ Map origOut = original.output(ph, outputs); + Map copyOut = copy.output(ph, outputs); + Map optimizedOut = optimized.output(ph, outputs); + + assertEquals(copyOut, origOut); + assertEquals(copyOut, optimizedOut); + + File f = new File(config.getTempFolder(), "optimized.sd"); + optimized.save(f, true); + + SameDiff loaded = SameDiff.load(f, true); + Map loadedOut = loaded.output(ph, outputs); + assertEquals(copyOut, loadedOut); + + //TODO add support for training checks! + //This is especially important for updaters... if we permute the weights, we should permute the updater state also + + //Check that nothing has changed (from the user API perspective) for the original graph + //i.e., + for(SDVariable v : copy.variables()){ + SDVariable ov = original.getVariable(v.name()); + + assertEquals(v.dataType(), ov.dataType()); + assertEquals(v.getVariableType(), ov.getVariableType()); + if(v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.VARIABLE){ + INDArray arrCopy = v.getArr(); + INDArray arrOrig = ov.getArr(); + assertEquals(arrCopy, arrOrig); + } + + } + + return optimized; + } + +}