diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java index 6414bec2f3f0..defc5bdee102 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java @@ -18,7 +18,7 @@ import lombok.extern.slf4j.Slf4j; import org.nd4j.common.tests.AbstractAssertTestsClass; import org.nd4j.common.tests.BaseND4JTest; -import org.nd4j.imports.TFGraphs.TFGraphTestAllLibnd4j; +import org.nd4j.imports.TFGraphs.TF2ImportTestsSameDiff; import org.nd4j.imports.TFGraphs.TFGraphTestAllSameDiff; import org.nd4j.imports.TFGraphs.TFGraphTestList; import org.nd4j.imports.TFGraphs.TFGraphTestZooModels; @@ -41,7 +41,7 @@ protected Set> getExclusions() { //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) return new HashSet<>(Arrays.asList( TFGraphTestAllSameDiff.class, - TFGraphTestAllLibnd4j.class, + TF2ImportTestsSameDiff.class, TFGraphTestList.class, TFGraphTestZooModels.class, ImportModelDebugger.class //Run manually only, otherwise ignored diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java new file mode 100644 index 000000000000..706f706d6231 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java @@ -0,0 +1,156 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.imports.TFGraphs; + +import lombok.extern.slf4j.Slf4j; +import org.junit.*; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.OpValidationSuite; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.*; + +@Slf4j +@RunWith(Parameterized.class) +@Ignore //AB 2020/05/12 - Disabled until TF 2.x import test resources are available +public class TF2ImportTestsSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests + + @Rule + public TestWatcher testWatcher = new TestWatcher() { + + @Override + protected void starting(Description description){ + log.info("TF2ImportTestsSameDiff: Starting parameterized test: " + description.getDisplayName()); + } + + //protected void failed(Throwable e, Description description) { + //protected void succeeded(Description description) { + }; + + private String modelName; + private TestCase testCase; + private String baseDir; + + private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; + public static final String[] BASE_DIRS = new String[]{"tf_graphs/examples2.2.0"}; //Add directories for any other TensorFlow versions here + public static final String MODEL_FILENAME = "frozen_model.pb"; + + public static final String[] IGNORE_REGEXES = new String[]{ + + }; + + /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have + all arrays printed during execution. + If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output + arrays will be printed during execution + */ + private final List debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*"); + + @BeforeClass + public static void beforeClass() { + Nd4j.setDataType(DataType.FLOAT); + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + } + + @Before + public void setup() { + Nd4j.setDataType(DataType.FLOAT); + Nd4j.getExecutioner().enableDebugMode(false); + Nd4j.getExecutioner().enableVerboseMode(false); + } + + @Parameterized.Parameters(name="{3}") + public static Collection data() throws Exception { + List out = new ArrayList<>(); + + for(String dir : BASE_DIRS) { + String version = dir.replaceAll("tf_graphs/examples", "tf"); + Map m = TFGraphUtil.getTestCases(dir, false, MODEL_FILENAME); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + for (String s : l) { + out.add(new Object[]{s, m.get(s), dir, version + "/" + s}); + } + } + return out; + } + + public TF2ImportTestsSameDiff(String name, TestCase tc, String baseDir, String displayName){ + this.modelName = name; + this.testCase = tc; + this.baseDir = baseDir; + } + + @Test + public void testOutputOnly() throws Exception { + if(TFGraphTestZooModels.isPPC()){ + /* + Ugly hack to temporarily disable tests on PPC only on CI + Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657 + These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions + */ + + log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657"); + OpValidationSuite.ignoreFailing(); + } + + + Nd4j.create(1); + + for(String s : IGNORE_REGEXES){ + if(modelName.matches(s)){ + log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); + OpValidationSuite.ignoreFailing(); + } + } + Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); + Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); + Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); + + boolean verboseDebugMode = false; + if(debugModeRegexes != null){ + for(String regex : debugModeRegexes){ + if(modelName.matches(regex)){ + verboseDebugMode = true; + break; + } + } + } + + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + + + try { + TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, baseDir, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); + //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); + } catch (Throwable t){ + log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t); + throw t; + } + //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 1cc3baa132ca..bf7eb8accd80 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -19,28 +19,25 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.apache.commons.io.FilenameUtils; -import org.apache.commons.io.IOUtils; -import org.apache.commons.lang3.math.NumberUtils; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.nd4j.autodiff.execution.NativeGraphExecutioner; - import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.InferenceSession; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; import org.nd4j.common.base.Preconditions; -import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; +import org.nd4j.common.function.BiFunction; +import org.nd4j.common.io.ClassPathResource; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.resources.Resources; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.listeners.ExecPrintListener; import org.nd4j.linalg.api.buffer.DataType; @@ -48,33 +45,22 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.function.BiFunction; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.string.NDArrayStrings; -import org.nd4j.common.util.ArrayUtil; import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.common.resources.strumpf.ResourceFile; -import org.nd4j.common.resources.strumpf.StrumpfResolver; -import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternResolver; import java.io.*; -import java.net.URI; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Pattern; import static org.junit.Assert.*; -import static org.nd4j.imports.TFGraphs.TFGraphsSkipNodes.skipNode; /** * Created by susaneraly on 11/6/17. @@ -124,20 +110,6 @@ public void tearDown() { .outputMode(OutputMode.VARIABLE_SPACE) .build(); - protected static List fetchTestParams(String baseDir, String modelFileName, ExecuteWith executeWith, File localTestDir) throws IOException { - String[] modelNames = modelDirNames(baseDir, executeWith, modelFileName); - List modelParams = new ArrayList<>(); - for (int i = 0; i < modelNames.length; i++) { - Object[] currentParams = new Object[4]; - currentParams[0] = inputVars(modelNames[i], baseDir, localTestDir); //input variable map - could be null - currentParams[1] = outputVars(modelNames[i], baseDir, localTestDir); //saved off predictions - currentParams[2] = modelNames[i]; - currentParams[3] = localTestDir; - modelParams.add(currentParams); - } - return modelParams; - } - protected static void checkOnlyOutput(Map inputs, Map predictions, String modelName, String baseDir, String modelFilename, ExecuteWith execType, BiFunction loader, Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException { @@ -183,7 +155,7 @@ protected static void checkOnlyOutput(Map inputs, Map inputs, Map inputs, Map inputs, String modelName, String baseDir, String modelFileName, - ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException { - checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging); - } - - public static void checkIntermediate(Map inputs, String modelName, String baseDir, String modelFileName, - ExecuteWith execType, BiFunction loader, - Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException { - Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" + - " must be null or both must be provided"); - Nd4j.EPS_THRESHOLD = 1e-3; - OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order - Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging); - SameDiff graph = p.getFirst(); - Map sdPredictions = p.getSecond(); - - //Collect coverage info about ops - OpValidation.collectTensorflowImportCoverage(graph); - - if (!execType.equals(ExecuteWith.JUST_PRINT)) { - int count = 0; - //Evaluate the nodes in their execution order - this is useful for debugging (as we want the *first* failure - // to be detected before later failures) - List varNames = new ArrayList<>(); - Map fns = graph.getOps(); - List execOrder = listener.getOpNamesList(); - for(String opName : execOrder){ - String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp()); - Collections.addAll(varNames, outputs); - } - - for (String varName : varNames) { - if (!inputs.containsKey(varName)) { //avoiding placeholders - INDArray tfValue = intermediateVars(modelName, baseDir, varName, localTestDir); - if (tfValue == null) { - continue; - } - log.info("Starting check: variable {}", varName); - if (skipNode(modelName, varName)) { - log.info("\n\tFORCING no check on " + varName); - } else { - assertArrayEquals("Shape not equal on node " + varName, tfValue.shape(), graph.getVariable(varName).getShape()); - INDArray sdVal = sdPredictions.get(varName); - if(maxRelErrorOverride != null){ - INDArray diff = Transforms.abs(tfValue.sub(sdVal), false); - INDArray absErrorMask = diff.gte(minAbsErrorOverride); //value 1 if x[i] > minAbsError; value 0 otherwise. Used to get rid of 1e-30 vs. 1e-29 type failures - INDArray sumAbs = Transforms.abs(tfValue, true).addi(Transforms.abs(sdVal, true)); - BooleanIndexing.replaceWhere(sumAbs, 1.0, Conditions.equals(0.0)); //Can only get 0.0 if both are zeros - need to avoid 0/0=NaN - INDArray relError = diff.divi(sumAbs); - relError.muli(absErrorMask); - - int countExceeds = Nd4j.getExecutioner().exec(new MatchCondition(relError, Conditions.greaterThan(maxRelErrorOverride))).getInt(0); - - double maxRE = -1; - //Mainly used for analysis in debugger: - DifferentialFunction op = null; - String[] opInputs = null; - if(countExceeds > 0){ - maxRE = relError.maxNumber().doubleValue(); - //Find the op that this variable is produced by - op = graph.getVariableOutputOp(varName); - opInputs = graph.getInputsForOp(op); - } - - - assertEquals( varName + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride - + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); - } else { -// assertEquals("Value not equal on node " + varName, tfValue, sdVal); - if(tfValue.equals(sdVal)){ - System.out.println("Pass: " + varName); - } else { - System.out.println("FAIL: " + varName); - System.out.println("TF:\n" + tfValue); - System.out.println("SD:\n" + sdVal); - } - - } - log.info("Values and shapes equal for {}", varName); - count++; - } - - } - } - - assertTrue("No intermediate variables were checked", count > 0); - } - - Nd4j.EPS_THRESHOLD = 1e-5; - } - public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, Set requiredOutputs, boolean printArraysDebugging) throws IOException { log.info("RUNNING TEST {}...", modelName); - SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); + File f = Resources.asFile(baseDir + "/" + modelName + "/" + modelFilename); + SameDiff graph = graphLoaderFunction.apply(f, modelName); if(listeners != null){ graph.setListeners(listeners); } @@ -452,353 +334,7 @@ private static String[] modelDirNames(String base_dir, ExecuteWith executeWith, return exampleNames; } - protected static Map inputVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.placeholder", true, localTestDir); - } - - - protected static Map outputVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.prediction", true, localTestDir); - } - - protected static Map inbetweenVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.prediction_inbw", true, localTestDir); - } - - - //return readVars(modelName, base_dir, "**.prediction_inbw", true); - - /** - * Possible for a single node to give multiple outputs - * - * How is a node that has a list of outputs like in the case of "node_multiple_out" work - * Below is hardcoded for a single node - */ - protected static INDArray intermediateVars(String modelName, String base_dir, String varName, File localTestDir) throws IOException { - //convert varName to convention used in naming files - // "/" replaced by "____"; followed by a digit indicating the output number followed by prediction_inbw.(shape|csv) - if (varName.contains(":")) { - varName = varName.replace(':', '.'); - } else { - varName = varName + ".0"; - } - String name = varName.replaceAll("/", "____") + ".prediction_inbw"; - Map nodeSepOutput = readVars(modelName, base_dir, name, true, localTestDir); - - boolean importNameWorkaround = false; - if(nodeSepOutput.isEmpty()){ - //Edge case: intermediates were generated with help of import_graph_def method, which by default adds "import/" to names - // for some reason. https://www.tensorflow.org/api_docs/python/tf/graph_util/import_graph_def - //So many of earlier intermediate nodes test data were generated with filenames like "import___X..." instead of "X..." - name = "import____" + name; - nodeSepOutput = readVars(modelName, base_dir, name, true, localTestDir); - importNameWorkaround = true; - } - - //required check for pattern matching as there are scopes and "*" above is a greedy match - Set removeList = confirmPatternMatch(nodeSepOutput.keySet(), importNameWorkaround ? "import/" + varName : varName); - for (String toRemove : removeList) { - nodeSepOutput.remove(toRemove); - } - if(importNameWorkaround){ - return nodeSepOutput.get("import/" + varName); //this *should* return a list of the indarrays for each node - } else { - return nodeSepOutput.get(varName); //this *should* return a list of the indarrays for each node - } - } - - public static Set confirmPatternMatch(Set setOfNames, String varName) { - Set removeList = new HashSet<>(); - for (String name : setOfNames) { - if (name.equals(varName)) continue; - String[] splitByPeriod = name.split("\\."); - //not a number - maybe another variable deeper in the same scope - if (!NumberUtils.isNumber(splitByPeriod[splitByPeriod.length - 1])) { - removeList.add(name); - } else if (!String.join(".", Arrays.copyOfRange(splitByPeriod, 0, splitByPeriod.length - 1)).equals(varName)) { - removeList.add(name); - } - } - return removeList; - } - - - protected static Map readVars(String modelName, String base_dir, String pattern, boolean recursive, File localTestDir) throws IOException { - Map varMap = new HashMap<>(); - String modelDir = base_dir + "/" + modelName; - - // key is variable name, value is data type - val dtypes = new HashMap(); - - List> resources = new ArrayList<>(); - if(recursive){ - String nameRegex = pattern.replace("**.",".*\\.") + "\\.shape"; -// File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString() + "/" + modelName); -// baseDir.mkdirs(); -// baseDir.deleteOnExit(); -// new ClassPathResource(modelDir).copyDirectory(baseDir); - - // checking out, if local folder declared - String localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - if(localPath != null && (!localPath.contains("src/main/resources") && !localPath.contains("src\\main\\resources"))){ - localPath = FilenameUtils.concat(localPath, "src/main/resources"); - } - - // baseDir will differ, depending on run mode - File baseDir = localPath == null ? new File(localTestDir, "extracted/" + modelName) : new File(localPath, base_dir + "/" + modelName); - String[] arr = baseDir.list(); - - if(!baseDir.exists() || arr == null || arr.length == 0){ - // we're skipping extraction if we're using local copy of dl4j-tests-resources - if (localPath == null) { - baseDir.mkdirs(); - baseDir.deleteOnExit(); - String md = modelDir; - if(!md.endsWith("/") && !md.endsWith("\\")){ - md = md + "/"; - } - new ClassPathResource(md).copyDirectory(baseDir); - } else{ - throw new IllegalStateException("local directory declared but could not find files: " + baseDir.getAbsolutePath()); - } - - } - - LinkedList queue = new LinkedList<>(); - queue.add(baseDir); - - while(!queue.isEmpty()){ - File subdir = queue.remove(); - File[] files = subdir.listFiles(); - if (files != null) { - for (File f : files) { - if (f.isDirectory()) { - queue.add(f); - } else { - String filename = f.getName(); - if(filename.matches(nameRegex)){ - File csvFile = new File(f.getAbsolutePath().replace(".shape",".csv")); - resources.add(new Pair<>(new FileSystemResource(f), new FileSystemResource(csvFile))); - } else if (filename.equals("dtypes")) { - List stringList; - - try (val is = new BufferedInputStream(new FileInputStream(f))) { - stringList = IOUtils.readLines(is, StandardCharsets.UTF_8); - - for (val s:stringList) { - val split = s.split("\\ "); - - val okey = split[0].replaceAll("____", "/"); - // adopt / in names - val key = modelDir + "/" + okey; - - // parse type directly - DataType value = ArrayOptionsHelper.dataType(split[1]); - - // adding key directly - //if (dtypes.containsKey(key)) - // throw new ND4JIllegalStateException("Specified key already exist: [" + key + "]"); - //else - - dtypes.put(key, value); - - // adding zero output duplicate (if it doesn't exist) - if (key.endsWith(".0")) { - val nkey = key.replaceAll("\\.0$",""); - if (!dtypes.containsKey(nkey)) { - dtypes.put(nkey, value); - } - } else if (key.endsWith(":0")) { - val nkey = key.replaceAll(":0$",""); - if (!dtypes.containsKey(nkey)) { - dtypes.put(nkey, value); - } - } - } - } catch (FileNotFoundException e) { - stringList = new ArrayList<>(); - } - } - } - } - } - } - } else { - ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(modelDir).getClassLoader()); - Resource[] r = resolver.getResources("classpath*:" + modelDir + "/" + pattern + ".shape"); - for(Resource res : r){ - String fileName = res.getFilename(); - String varPath = modelDir + "/" + fileName; - Resource r2 = new org.springframework.core.io.ClassPathResource(varPath.replace(".shape", ".csv")); - resources.add(new Pair<>(res, r2)); - } - - } - -// Preconditions.checkState(!dtypes.isEmpty(), "No datatypes file was found"); - val dtype = Nd4j.dataType(); - for (int i = 0; i < resources.size(); i++) { - URI u = resources.get(i).getFirst().getURI(); - String varName = u.toString(); - int idx = varName.indexOf(modelName); - varName = varName.substring(idx + modelName.length()+1); //+1 for "/" - varName = varName.replaceAll("____","/"); - varName = varName.replaceAll(".placeholder.shape",""); - varName = varName.replaceAll(".prediction.shape",""); - varName = varName.replaceAll(".prediction_inbw.shape",""); - - DataType type = dtypes.get(modelDir + "/" + varName); - - List lines; //= FileUtils.readLines(new ClassPathResource(varPath).getFile(), Charset.forName("UTF-8")); - try(InputStream is = new BufferedInputStream(resources.get(i).getFirst().getInputStream())){ - lines = IOUtils.readLines(is, StandardCharsets.UTF_8); - } - List filtered = new ArrayList<>(lines.size()); - for(String s : lines){ - String trimmed = s.trim(); - if(!trimmed.isEmpty()){ - filtered.add(trimmed); - } - } - - if(type == null){ - log.warn("DATATYPE NOT AVAILABLE FOR: {} - {}", modelName, varName); - //Soon: this will be an exception - type = DataType.FLOAT; - } - - INDArray varValue; - if(filtered.size() == 0){ - //Scalar - String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); - switch (type){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - varValue = Nd4j.scalar(type, parseDouble(content)); - break; - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case UINT16: - case UINT32: - case UINT64: - varValue = Nd4j.scalar(type, parseLong(content)); - break; - case BOOL: - varValue = Nd4j.scalar(parseBoolean(content)); - break; - case UTF8: - varValue = Nd4j.scalar(content); - break; - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); - } - } else { - int[] varShape = new int[filtered.size()]; - for( int j=0; j p = resources.get(i); - boolean isRef = p.getSecond().isFile() && !p.getSecond().exists(); - - InputStream stream; - if(isRef){ - //Slight hack for loading strumpf reference files - File r = new StrumpfResolver().localCacheRoot(); - String path = p.getSecond().getFile() + StrumpfResolver.REF; - File f = ResourceFile.fromFile(path).localFile(r); - stream = new BufferedInputStream(new FileInputStream(f)); - } else { - stream = new BufferedInputStream(resources.get(i).getSecond().getInputStream()); - } - - try(InputStream is = stream){ - content = String.join("\n", IOUtils.readLines(is, StandardCharsets.UTF_8)); - } - - if (content.isEmpty()) { - //Should be zeros in shape - boolean foundZero = false; - for( int s : varShape){ - foundZero |= (s == 0); - } - if(foundZero){ - varValue = Nd4j.create(type, ArrayUtil.toLongArray(varShape)); - } else { - throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond()); - } - } else { - if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of [] - varShape = new int[0]; - - String[] cLines = content.split("\n"); - switch (type){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - double[] dArr = new double[cLines.length]; - int x=0; - while(x < dArr.length){ - dArr[x] = parseDouble(cLines[x]); - x++; - } - varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape); - break; - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case UINT16: - case UINT32: - case UINT64: - long[] lArr = new long[cLines.length]; - int y=0; - while(y < lArr.length){ - lArr[y] = parseLong(cLines[y]); - y++; - } - varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape); - break; - case BOOL: - boolean[] bArr = new boolean[cLines.length]; - int z=0; - while(z < bArr.length){ - bArr[z] = parseBoolean(cLines[z]); - z++; - } - varValue = Nd4j.createFromArray(bArr).reshape('c', varShape); - break; - case UTF8: - varValue = Nd4j.create(cLines).reshape('c', varShape); - break; - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); - } - } - } catch (NumberFormatException e) { - log.warn("Error parsing number", e); - continue; - } - } - - varMap.put(varName, varValue); - } - return varMap; - } private static long parseLong(String line){ line = line.trim(); //Handle whitespace diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java deleted file mode 100644 index 9b6b0f372846..000000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java +++ /dev/null @@ -1,188 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.TFGraphs; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.junit.*; -import org.junit.rules.TestWatcher; -import org.junit.runner.Description; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; -import org.nd4j.nativeblas.NativeOpsHolder; - -import java.io.File; -import java.io.IOException; -import java.util.*; - -/** - * Created by susaneraly on 11/29/17. - */ -@RunWith(Parameterized.class) -@Slf4j -@Ignore("AB 2019/05/21 - JVM Crashes - Issue #7657") -public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests - - @Rule - public TestWatcher testWatcher = new TestWatcher() { - - @Override - protected void starting(Description description){ - log.info("TFGraphTestAllLibnd4j: Starting parameterized test: " + description.getDisplayName()); - } - - //protected void failed(Throwable e, Description description) { - //protected void succeeded(Description description) { - }; - - private Map inputs; - private Map predictions; - private String modelName; - private File localTestDir; - - private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.LIBND4J; - private static final String BASE_DIR = "tf_graphs/examples"; - private static final String MODEL_FILENAME = "frozen_model.pb"; - - private static final String[] SKIP_FOR_LIBND4J_EXEC = new String[]{ - //Exceptions - need to look into: - "alpha_dropout/.*", - "layers_dropout/.*", - //"losses/.*", - - //These can't pass until this is fixed: https://github.com/deeplearning4j/deeplearning4j/issues/6465#issuecomment-424209155 - //i.e., reduction ops with newFormat/keepDims args - //"l2_normalize/.*", - //"norm_tests/.*", - "g_06", - - //JVM crashes - "simpleif.*", - "simple_cond.*", - - //2019/01/24 - Failing - "cond/cond_true", - "simplewhile_.*", - "simple_while", - "while1/.*", - "while2/a", - - //2019/01/24 - TensorArray support missing at libnd4j exec level?? - "tensor_array/.*", - - //2019/02/04 - Native execution exception: "Graph wasn't toposorted" - "primitive_gru_dynamic", - - //2019/02/08 - Native execution exception: "Graph wasn't toposorted". Note it's only the dynamic (while loop) RNNs - "rnn/basiclstmcell/dynamic.*", - "rnn/basicrnncell/dynamic.*", - "rnn/bidir_basic/dynamic.*", - "rnn/fused_adapt_basic/dynamic.*", - "rnn/grucell/dynamic.*", - "rnn/lstmcell/dynamic.*", - "rnn/srucell/dynamic.*", - - //2019/02/23 Passing for SameDiff exec, failing for libnd4j exec - "rnn/grublockcellv2/.*", - "rnn/lstmblockcell/.*", - "rnn/lstmblockfusedcell/.*", - }; - - @BeforeClass - public static void beforeClass() { - Nd4j.setDataType(DataType.FLOAT); - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); - } - - @Before - public void setup(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @After - public void tearDown() { - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); - } - - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { - val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - - // if this variable isn't set - we're using dl4j-tests-resources - if (localPath == null) { - File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - } else { - File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - } - } - - public TFGraphTestAllLibnd4j(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; - } - - @Test//(timeout = 25000L) - public void test() throws Exception { - if(TFGraphTestZooModels.isPPC()){ - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } - - Nd4j.create(1); - for(String s : TFGraphTestAllSameDiff.IGNORE_REGEXES){ - if(modelName.matches(s)){ - log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); - } - } - - for(String s : SKIP_FOR_LIBND4J_EXEC){ - if(modelName.matches(s)){ - log.info("\n\tIGNORE MODEL ON REGEX - SKIP LIBND4J EXEC ONLY: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); - } - } - - log.info("Starting test: {}", this.modelName); - Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); - Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); - Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); - - TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, - TFGraphTestAllHelper.LOADER, maxRE, minAbs, false); - //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH); - } - -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 72c705852f57..21fefdaab9eb 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -54,14 +54,12 @@ protected void starting(Description description){ //protected void succeeded(Description description) { }; - private Map inputs; - private Map predictions; private String modelName; - private File localTestDir; + private TestCase testCase; private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; - private static final String BASE_DIR = "tf_graphs/examples"; - private static final String MODEL_FILENAME = "frozen_model.pb"; + public static final String BASE_DIR = "tf_graphs/examples"; + public static final String MODEL_FILENAME = "frozen_model.pb"; public static final String[] IGNORE_REGEXES = new String[]{ //Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958 @@ -203,26 +201,22 @@ public void setup() { public void tearDown() { } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { - val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - - // if this variable isn't set - we're using dl4j-tests-resources - if (localPath == null) { - File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - return params; - } else { - File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { + Map m = TFGraphUtil.getTestCases(BASE_DIR, false, MODEL_FILENAME); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + + List out = new ArrayList<>(l.size()); + for(String s : l){ + out.add(new Object[]{s, m.get(s)}); } + return out; } - public TFGraphTestAllSameDiff(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; + public TFGraphTestAllSameDiff(String name, TestCase tc){ + this.modelName = name; + this.testCase = tc; } @Test//(timeout = 25000L) @@ -261,6 +255,10 @@ public void testOutputOnly() throws Exception { } } + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + + try { TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 52aed2d1c210..1e814fd88a4d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -29,10 +29,7 @@ import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; +import java.util.*; /** * TFGraphTestAll* will run all the checked in TF graphs and @@ -55,7 +52,7 @@ public class TFGraphTestList { public static final boolean printArraysDebugging = false; public static String[] modelNames = new String[]{ - "resize_nearest_neighbor/int32" + "emptyArrayTests/identity_n/rank1" }; @After @@ -78,27 +75,32 @@ public static void beforeClass(){ } private String modelName; + private TestCase testCase; - @Parameterized.Parameters - public static Collection data() { - List modelNamesParams = new ArrayList<>(); - for (int i = 0; i < modelNames.length; i++) { - Object[] currentParams = new String[]{modelNames[i]}; - modelNamesParams.add(currentParams); + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { + + List out = new ArrayList<>(modelNames.length); + for(int i=0; i inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); - Map predictions = TFGraphTestAllHelper.outputVars(modelName, MODEL_DIR, dir); + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); @@ -106,12 +108,4 @@ public void testOutputOnly() throws IOException { TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, TFGraphTestAllHelper.LOADER, maxRE, minAbs, printArraysDebugging); } - - @Test @Ignore - public void testAlsoIntermediate() throws IOException { - //Nd4jCpu.Environment.getInstance().setUseMKLDNN(false); - File dir = testDir.newFolder(); - Map inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); - TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging); - } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index d8550673e95b..b7ca3e3d3dfa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -43,10 +43,7 @@ import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; +import java.util.*; @RunWith(Parameterized.class) @Slf4j @@ -103,9 +100,8 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we private static final String BASE_DIR = "tf_graphs/zoo_models"; private static final String MODEL_FILENAME = "tf_model.txt"; - private Map inputs; - private Map predictions; private String modelName; + private TestCase testCase; private File localTestDir; public static String getBaseModelDir(){ @@ -206,18 +202,26 @@ public static void beforeClass(){ Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { + Map m = TFGraphUtil.getTestCases(BASE_DIR, false, MODEL_FILENAME); + classTestDir.create(); File baseDir = classTestDir.newFolder(); // new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, baseDir); - return params; + + List out = new ArrayList<>(); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + for (String s : l) { + out.add(new Object[]{s, m.get(s), baseDir}); + } + + return out; } - public TFGraphTestZooModels(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; + public TFGraphTestZooModels(String modelName, TestCase tc, File localTestDir) { this.modelName = modelName; + this.testCase = tc; this.localTestDir = localTestDir; } @@ -266,6 +270,10 @@ public void testOutputOnly() throws Exception { Double maxRE = 1e-3; Double minAbs = 1e-4; currentTestDir = testDir.newFolder(); + + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + log.info("----- SameDiff Exec: {} -----", modelName); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, LOADER, maxRE, minAbs, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java new file mode 100644 index 000000000000..dc721572436a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -0,0 +1,339 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.imports.TFGraphs; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; +import org.nd4j.common.base.Preconditions; +import org.nd4j.common.resources.Resources; +import org.nd4j.common.tests.ResourceUtils; +import org.nd4j.common.util.ArrayUtil; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.*; + +@Slf4j +public class TFGraphUtil { + + private TFGraphUtil() { + } + + public static TestCase getTestCase(String baseDir, String testName, String modelFilename) throws Exception { + String newBase = FilenameUtils.concat(baseDir, testName + "/"); + Map cases = getTestCases(newBase, true, modelFilename); + Preconditions.checkState(cases.size() == 1, "Expected 1 test case, got %s", cases.size()); + return cases.get(cases.keySet().iterator().next()); + } + + public static Map getTestCases(String baseDir, boolean singleTest, String modelFilename) throws Exception { + + baseDir = baseDir.replaceAll("\\\\", "/"); + if (!baseDir.endsWith("/")) + baseDir += "/"; + + List l = ResourceUtils.listClassPathFiles(baseDir, true, false); + + Set set = new HashSet<>(l); + Map map = new HashMap<>(); + + for (String s : l) { + String sub = s.substring(baseDir.length()); + + int idx = singleTest ? 0 : sub.lastIndexOf('/'); + + boolean badTest = false; + String name = null; + String modelDir = null; + if (singleTest) { + name = ""; + modelDir = baseDir; + } else if (idx > 0) { + name = sub.substring(0, idx); + modelDir = baseDir + sub.substring(0, idx + 1); + String expModel = modelDir + modelFilename; + while (!set.contains(expModel) && idx > 0) { + //Due to a mixing of directories and variable names - we + //For example we might have "X/frozen_model.pb" + //And then also "X/something/or/other.csv + //When this occurs - we should look up the path to determine which part is the model name + // and which part is the variable name + idx = sub.lastIndexOf('/', idx); + if (idx < 0) { + System.out.println("***** BAD TEST DIRECTORY: " + s + " ******"); + badTest = true; + break; + } + + sub = sub.substring(0, idx); + expModel = baseDir + sub + "/" + modelFilename; + modelDir = baseDir + sub + "/"; + name = sub; + } + } + + if(badTest || modelDir == null) + continue; + + TestCase tc = map.get(name); + if (tc == null) { + tc = new TestCase(name, null, null, null); + map.put(name, tc); + } + + if (s.endsWith("prediction.csv")) { + if (tc.outputs == null) + tc.outputs = new HashMap<>(); + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); + varName = varName.substring(0, varName.length() - "prediction.csv".length() - 1); + tc.outputs.put(varName, s); + } else if (s.endsWith("placeholder.csv")) { + if (tc.inputs == null) + tc.inputs = new HashMap<>(); + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); + varName = varName.substring(0, varName.length() - "placeholder.csv".length() - 1); + tc.inputs.put(varName, s); + } else if (s.endsWith("/dtypes")) { + File f = Resources.asFile(s); + List lines = FileUtils.readLines(f, StandardCharsets.UTF_8); + tc.datatypes = new HashMap<>(); + for (String line : lines) { + String[] split = line.split(" "); + Preconditions.checkState(split.length == 2, "Expected 2 entries in dtypes file, got %s", split.length); + String key = split[0].replaceAll("____", "/"); + DataType value = ArrayOptionsHelper.dataType(split[1]); + + // adding zero output duplicate (if it doesn't exist) + if (key.endsWith(".0")) { + val nkey = key.replaceAll("\\.0$", ""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); + } + } else if (key.endsWith(":0")) { + val nkey = key.replaceAll(":0$", ""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); + } + } else { + tc.datatypes.put(split[0], value); + } + } + } + } + return map; + } + + private static long parseLong(String line) { + line = line.trim(); //Handle whitespace + if (line.matches("-?\\d+\\.0+")) { + //Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000" + return Long.parseLong(line.substring(0, line.indexOf('.'))); + } else { + return Long.parseLong(line); + } + } + + private static double parseDouble(String line) { + line = line.trim(); //Handle whitespace - some lines are like " -inf" + if ("nan".equalsIgnoreCase(line)) { + return Double.NaN; + } else if ("inf".equalsIgnoreCase(line)) { + return Double.POSITIVE_INFINITY; + } else if ("-inf".equalsIgnoreCase(line)) { + return Double.NEGATIVE_INFINITY; + } else { + return Double.parseDouble(line); + } + } + + private static boolean parseBoolean(String line) { + line = line.trim(); + if (line.matches("1(\\.0*)?")) { //Booleans are ocassionally represented like 1.000000 or 0.000000 + return true; + } else if (line.matches("0(\\.0*)?")) { + return false; + } + return Boolean.parseBoolean(line); + } + + public static INDArray loadCsv(String path, String varName, @NonNull TestCase tc) throws IOException { + + DataType type; + if(tc.datatypes == null){ + log.warn("No datatype available for: {}", path); + type = DataType.FLOAT; + } else { + type = tc.datatypes.get(varName); + } + + String shapeFile = path.substring(0, path.length() - 4) + ".shape"; + List shapeLines = FileUtils.readLines(Resources.asFile(shapeFile), StandardCharsets.UTF_8); + List filteredShape = new ArrayList<>(shapeLines.size()); + for (String s : shapeLines) { + String trimmed = s.trim(); + if (!trimmed.isEmpty()) { + filteredShape.add(trimmed); + } + } + + INDArray varValue = null; + if (filteredShape.size() == 0) { + //Scalar + String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); //IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); + switch (type) { + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + varValue = Nd4j.scalar(type, parseDouble(content)); + break; + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case UINT32: + case UINT64: + varValue = Nd4j.scalar(type, parseLong(content)); + break; + case BOOL: + varValue = Nd4j.scalar(parseBoolean(content)); + break; + case UTF8: + varValue = Nd4j.scalar(content); + break; + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); + } + } else { + int[] varShape = new int[filteredShape.size()]; + for (int j = 0; j < filteredShape.size(); j++) { + varShape[j] = Integer.parseInt(filteredShape.get(j)); + } + + try { + String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); + + if (content.isEmpty()) { + //Should be zeros in shape + boolean foundZero = false; + for (int s : varShape) { + foundZero |= (s == 0); + } + if (foundZero) { + varValue = Nd4j.create(type, ArrayUtil.toLongArray(varShape)); + } else { + throw new IllegalStateException("Empty data but non-empty shape: " + shapeFile); + } + } else { + if (varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of [] + varShape = new int[0]; + + String[] cLines = content.split("\n"); + switch (type) { + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + double[] dArr = new double[cLines.length]; + int x = 0; + while (x < dArr.length) { + dArr[x] = parseDouble(cLines[x]); + x++; + } + varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape); + break; + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case UINT32: + case UINT64: + long[] lArr = new long[cLines.length]; + int y = 0; + while (y < lArr.length) { + lArr[y] = parseLong(cLines[y]); + y++; + } + varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape); + break; + case BOOL: + boolean[] bArr = new boolean[cLines.length]; + int z = 0; + while (z < bArr.length) { + bArr[z] = parseBoolean(cLines[z]); + z++; + } + varValue = Nd4j.createFromArray(bArr).reshape('c', varShape); + break; + case UTF8: + varValue = Nd4j.create(cLines).reshape('c', varShape); + break; + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); + } + } + } catch (NumberFormatException e) { + log.warn("Error parsing number", e); +// continue; + } + } + + return varValue; + } + + + public static Map loadInputs(TestCase testCase) throws IOException { + Map inputs = null; + if(testCase.inputs != null){ + inputs = new HashMap<>(); + for(String s : testCase.inputs.keySet()){ + String path = testCase.inputs.get(s); + INDArray arr = TFGraphUtil.loadCsv(path, s, testCase); + inputs.put(s, arr); + } + } + return inputs; + } + + public static Map loadPredictions(TestCase testCase) throws IOException { + Map predictions = null; + if(testCase.outputs != null){ + predictions = new HashMap<>(); + for(String s : testCase.outputs.keySet()){ + String path = testCase.outputs.get(s); + INDArray arr = TFGraphUtil.loadCsv(path, s, testCase); + predictions.put(s, arr); + } + } + return predictions; + } +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java new file mode 100644 index 000000000000..d7b23660a56f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java @@ -0,0 +1,31 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.imports.TFGraphs; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.api.buffer.DataType; + +import java.util.Map; + +@AllArgsConstructor +@Data +public class TestCase { + public String modelName; + public Map inputs; //Key: variable name, values: filename (.csv) + public Map outputs; + public Map datatypes; +} diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java index bafe94094d41..233e99595ac5 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java @@ -61,11 +61,17 @@ public static List listClassPathFiles(String path, boolean recursive, bo private static List listClassPathFilesHelper(String path, boolean recursive, boolean includeDirectories, String... extensions) throws IOException { ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(path).getClassLoader()); + if(path.contains("\\")) + path = path.replaceAll("\\\\", "/"); + + if(!path.endsWith("/")) + path = path + "/"; + StringBuilder sbPattern = new StringBuilder("classpath*:" + path); if (recursive) { - sbPattern.append("/**/*"); + sbPattern.append("**/*"); } else { - sbPattern.append("/*"); + sbPattern.append("*"); } //Normalize extensions so they are all like ".csv" etc - with leading "."