Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.*;
import org.nd4j.autodiff.samediff.ops.*;
import org.nd4j.autodiff.samediff.optimize.GraphOptimizer;
import org.nd4j.autodiff.samediff.optimize.OptimizationConfig;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
Expand Down Expand Up @@ -109,6 +111,7 @@
@Slf4j
public class SameDiff extends SDBaseOps {
protected static final String GRAD_FN_KEY = "grad";
protected static final String OPTIMIZED_FN_KEY = "optimized";

//Fields for graph structure and execution
@Getter
Expand All @@ -118,7 +121,9 @@ public class SameDiff extends SDBaseOps {
@Getter
private final Map<Long, InferenceSession> sessions = new ConcurrentHashMap<>(); //Key: thread ID

@Getter @Setter //TODO shouldn't be in public API
private ArrayHolder constantArrays = new ThreadSafeArrayHolder(true);
@Getter @Setter //TODO shouldn't be in public API
private ArrayHolder variablesArrays = new ThreadSafeArrayHolder(true);
private final Map<Long, Map<String, INDArray>> placeholdersPerThread = new ConcurrentHashMap<>(); //Placeholders for each thread - if the user sets them

Expand Down Expand Up @@ -238,6 +243,11 @@ public SDBitwise bitwise(){
return bitwise;
}

@Setter @Getter
private boolean allowOptimization = true;
private String[] optimizedWRT = null;


private Map<String, SameDiff> sameDiffFunctionInstances;

private Table<String, String, String> fieldVariableResolutionMapping;
Expand Down Expand Up @@ -2554,6 +2564,23 @@ protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placehol
activeListeners.add(l);
}

if(allowOptimization){
if(!sameDiffFunctionInstances.containsKey(OPTIMIZED_FN_KEY) || optimizedWRT == null || !Arrays.equals(optimizedWRT, outputs)){
//Need to create optimized version

SameDiff sd = optimize(Arrays.asList(outputs));
sameDiffFunctionInstances.put(OPTIMIZED_FN_KEY, sd);


//TODO clean up old version optimized SameDiff if necessary
}
SameDiff optimized = sameDiffFunctionInstances.get(OPTIMIZED_FN_KEY);
if(optimized.isAllowOptimization())
optimized.setAllowOptimization(false); //Prevent recursive optimizations

return optimized.batchOutputHelper(placeholders, activeListeners, operation, outputs);
}

for (Listener l : activeListeners) {
l.operationStart(this, operation);
}
Expand Down Expand Up @@ -5863,4 +5890,10 @@ public String generateDistinctCustomVariableName(String base){

return base + "_" + inc;
}

protected SameDiff optimize(List<String> withRespectToOutputs){
SameDiff sd = GraphOptimizer.optimize(this, withRespectToOutputs);
sd.setAllowOptimization(false); //Prevent recursive optimization attempts when output is called
return sd;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package org.nd4j.autodiff.samediff.array;

import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.function.Supplier;

import java.util.*;

public class OptimizedGraphArrayHolder implements ArrayHolder {

private final ArrayHolder underlyingHolder;
private final Map<String, Supplier<INDArray>> functions;

public OptimizedGraphArrayHolder(ArrayHolder underlyingHolder){
this.underlyingHolder = underlyingHolder;
this.functions = new HashMap<>();
}

public void setFunction(String name, Supplier<INDArray> fn){
if(underlyingHolder.hasArray(name))
underlyingHolder.removeArray(name);
functions.put(name, fn);
}

@Override
public boolean hasArray(String name) {
return functions.containsKey(name) || underlyingHolder.hasArray(name);
}

@Override
public INDArray getArray(String name) {
if(functions.containsKey(name))
return functions.get(name).get();
return underlyingHolder.getArray(name);
}

@Override
public void setArray(String name, INDArray array) {
Preconditions.checkState(!functions.containsKey(name), "Cannot set array when existing array is only accessible via a function");
underlyingHolder.setArray(name, array);
}

@Override
public INDArray removeArray(String name) {
Supplier<INDArray> s = functions.remove(name);
if(s != null)
return s.get();
return underlyingHolder.removeArray(name);
}

@Override
public int size() {
return underlyingHolder.size() + functions.size();
}

@Override
public void initFrom(ArrayHolder arrayHolder) {
underlyingHolder.initFrom(arrayHolder);
}

@Override
public Collection<String> arrayNames() {
Set<String> set = new HashSet<>();
set.addAll(underlyingHolder.arrayNames());
set.addAll(functions.keySet());
return set;
}

@Override
public void rename(String from, String to) {
if(functions.containsKey(from)) {
functions.put(to, functions.remove(from));
} else {
underlyingHolder.rename(from, to);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff.optimize;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger;
import org.nd4j.autodiff.samediff.optimize.optimizations.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
*
* @author Alex Black
*/
@Slf4j
public class GraphOptimizer {

public static List<OptimizerSet> defaultOptimizations(){
return Arrays.<OptimizerSet>asList(
new UnusedFunctionOptimizations(),
new ConstantFunctionOptimizations(),
new IdentityFunctionOptimizations(),
new ShapeFunctionOptimizations(),
new UnusedFunctionOptimizations(),
new CuDNNFunctionOptimizations()
);
}

public static SameDiff optimize(SameDiff graph, String... requiredOutputs){
return optimize(graph, Arrays.asList(requiredOutputs));
}

public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs){
return optimize(graph, requiredOutputs, defaultOptimizations());
}

public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations) {
return optimize(graph, requiredOutputs, optimizations, null);
}

public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations, OptimizationDebugger debugger){
//TODO Use required outputs - strip unnecessary graph components

SameDiff sd = graph.dup();

ArrayHolder cArr = sd.getConstantArrays();
ArrayHolder vArr = sd.getVariablesArrays();

OptimizationHelper h = new OptimizationHelper(graph, new OptimizationConfig()); //TODO defaults for config

for( int i=0; i<3; i++ ) { //Run multiple times - one run isn't enough, as some more optimizations may need to be applied to the output of earlier optimizations
for (OptimizerSet s : optimizations) {
List<Optimizer> l = s.getOptimizers();
for(Optimizer o : l ){
Collection<SameDiffOp> startingOps = new ArrayList<>(sd.getOps().values()); //Create list to avoid concurrent modification exception
for(SameDiffOp op : startingOps) {
//Because ops might disappear from previous optimization steps, we need to check if the previous op
// still exists when iterating...
if(!sd.getOps().containsKey(op.getName()))
continue;

if(debugger != null)
debugger.beforeOptimizationCheck(sd, op, o);

boolean applied = o.checkAndApply(sd, h, op, cArr, vArr);
if(applied) {
log.info("Operation was applied: {}", o);
}

if(debugger != null)
debugger.afterOptimizationsCheck(sd, op, o, applied);
}
}
}
}

int constBefore = 0;
int constAfter = 0;
int varBefore = 0;
int varAfter = 0;
int arrBefore = 0;
int arrAfter = 0;

for(SDVariable v : graph.variables()){
switch(v.getVariableType()){
case VARIABLE:
varBefore++;
break;
case CONSTANT:
constBefore++;
break;
case ARRAY:
arrBefore++;
break;
case PLACEHOLDER:
break;
}
}

for(SDVariable v : sd.variables()){
switch(v.getVariableType()){
case VARIABLE:
varAfter++;
break;
case CONSTANT:
constAfter++;
break;
case ARRAY:
arrAfter++;
break;
case PLACEHOLDER:
break;
}
}


log.info("Total variables: {} before, {} after", graph.getVariables().size(), sd.getVariables().size());
log.info("Constant variables: {} before, {} after", constBefore, constAfter);
log.info("Array type variables: {} before, {} after", arrBefore, arrAfter);
log.info("Variable type variables: {} before, {} after", varBefore, varAfter);
log.info("Ops: {} before, {} after", graph.getOps().size(), sd.getOps().size());

return sd;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff.optimize;

import java.util.Properties;

public class OptimizationConfig extends Properties {

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff.optimize;

import lombok.Getter;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.array.OptimizedGraphArrayHolder;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.function.Supplier;

import java.util.Properties;

public class OptimizationHelper {

private final SameDiff originalGraph;
@Getter
private final Properties properties;
private boolean setConstantHolder = false;
private boolean setVariableHolder = false;

public OptimizationHelper(SameDiff originalGraph, Properties properties){
this.originalGraph = originalGraph;
this.properties = properties;
}

public OptimizationHelper arrayRecoveryFunction(String arrayName, Supplier<INDArray> fn){
SDVariable v = originalGraph.getVariable(arrayName);
Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT,
"Can only set an array recovery function for a variable or a constant");

if(v.getVariableType() == VariableType.VARIABLE){
ArrayHolder h = originalGraph.getVariablesArrays();
if(!setVariableHolder){
originalGraph.setVariablesArrays(new OptimizedGraphArrayHolder(h));
h = originalGraph.getVariablesArrays();
setVariableHolder = true;
}
((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn);
} else {
ArrayHolder h = originalGraph.getConstantArrays();
if(!setConstantHolder){
originalGraph.setConstantArrays(new OptimizedGraphArrayHolder(h));
h = originalGraph.getConstantArrays();
setConstantHolder = true;
}
((OptimizedGraphArrayHolder)h).setFunction(arrayName, fn);
}

return this;
}

}
Loading