From 3918677150827407053280600f753986aa210a16 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 11 May 2020 22:14:07 +1000 Subject: [PATCH 1/8] Initial refactoring for TF 1 & 2 import Signed-off-by: Alex Black --- .../TFGraphs/TFGraphTestAllSameDiff.java | 79 +++-- .../imports/TFGraphs/TFGraphTestList.java | 2 +- .../nd4j/imports/TFGraphs/TFGraphUtil.java | 316 ++++++++++++++++++ .../org/nd4j/common/tests/ResourceUtils.java | 7 +- 4 files changed, 381 insertions(+), 23 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index b76b45dcdafe..4732f7f91c12 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -54,14 +54,15 @@ protected void starting(Description description){ //protected void succeeded(Description description) { }; - private Map inputs; - private Map predictions; +// private Map inputs; +// private Map predictions; private String modelName; private File localTestDir; + private TestCase testCase; private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; private static final String BASE_DIR = "tf_graphs/examples"; - private static final String MODEL_FILENAME = "frozen_model.pb"; + public static final String MODEL_FILENAME = "frozen_model.pb"; public static final String[] IGNORE_REGEXES = new String[]{ //Failing 2019/07/01 - Issue 10, https://github.com/deeplearning4j/deeplearning4j/issues/6958 @@ -150,26 +151,47 @@ public void setup() { public void tearDown() { } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { - val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - - // if this variable isn't set - we're using dl4j-tests-resources - if (localPath == null) { - File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - return params; - } else { - File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { +// long start = System.currentTimeMillis(); +// val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); +// +// // if this variable isn't set - we're using dl4j-tests-resources +// if (localPath == null) { +// File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); +// List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); +// long end = System.currentTimeMillis(); +// System.out.println("TIME TO GET PARAMETERS: " + (end-start)); +// return params; +// } else { +// File baseDir = new File(localPath); +// List l = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); +// long end = System.currentTimeMillis(); +// System.out.println("TIME TO GET PARAMETERS: " + (end-start)); +// return l; +// } + + Map m = TFGraphUtil.getTestCases(BASE_DIR); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + + List out = new ArrayList<>(l.size()); + for(String s : l){ + out.add(new Object[]{s, m.get(s)}); } + return out; } - public TFGraphTestAllSameDiff(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; +// public TFGraphTestAllSameDiff(Map inputs, Map predictions, String modelName, File localTestDir) { +// this.inputs = inputs; +// this.predictions = predictions; +// this.modelName = modelName; +// this.localTestDir = localTestDir; +// } + + public TFGraphTestAllSameDiff(String name, TestCase tc){ + this.modelName = name; + this.testCase = tc; } @Test//(timeout = 25000L) @@ -208,6 +230,23 @@ public void testOutputOnly() throws Exception { } } + Map inputs = new HashMap<>(); + Map predictions = new HashMap<>(); + + if(testCase.inputs != null){ + for(String s : testCase.inputs.keySet()){ + INDArray arr = TFGraphUtil.loadCsv(s, testCase); + inputs.put(s, arr); + } + } + + if(testCase.outputs != null){ + for(String s : testCase.outputs.keySet()){ + INDArray arr = TFGraphUtil.loadCsv(s, testCase); + predictions.put(s, arr); + } + } + try { TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 52aed2d1c210..2c7f1e8575cd 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -55,7 +55,7 @@ public class TFGraphTestList { public static final boolean printArraysDebugging = false; public static String[] modelNames = new String[]{ - "resize_nearest_neighbor/int32" + "bitcast/from_int32_to_uint32" }; @After diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java new file mode 100644 index 000000000000..de1d73c5521f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -0,0 +1,316 @@ +package org.nd4j.imports.TFGraphs; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.apache.commons.io.FileUtils; +import org.nd4j.common.base.Preconditions; +import org.nd4j.common.resources.Resources; +import org.nd4j.common.tests.ResourceUtils; +import org.nd4j.common.util.ArrayUtil; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; +import org.nd4j.linalg.factory.Nd4j; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.*; + +@Slf4j +public class TFGraphUtil { + + private TFGraphUtil(){ } + + public static Map getTestCases(String baseDir) throws Exception { + + long start = System.currentTimeMillis(); +// String baseDir = "tf_graphs/examples/"; + List l = ResourceUtils.listClassPathFiles(baseDir, true, false); + long end = System.currentTimeMillis(); + + Set listAsSet = new HashSet<>(l); + + + Set modelSet = new HashSet<>(); + List modelFileNames = new ArrayList<>(); + + + Map map = new HashMap<>(); + + long start2 = System.currentTimeMillis(); + for(String s : l){ + String sub = s.substring(baseDir.length()+1); +// int idx = sub.indexOf('/'); +// if(idx > 0){ +// String model +// } + + int idx = sub.lastIndexOf('/'); + + + if(idx > 0) { + String name = sub.substring(0, idx); + String modelDir = baseDir + sub.substring(0,idx+1); + String expModel = modelDir + TFGraphTestAllSameDiff.MODEL_FILENAME; +// while(!Resources.exists(expModel) && idx > 0){ + while(!listAsSet.contains(expModel) && idx > 0){ + //Due to a mixing of directories and variable names - we + //For example we might have "X/frozen_model.pb" + //And then also "X/something/or/other.csv + //When this occurs - we should look up the path to determine which part is the model name + // and which part is the variable name + idx = sub.lastIndexOf('/', idx); + if(idx < 0){ + System.out.println("***** BAD TEST DIRECTORY: " + s + " ******"); + continue; + } + + sub = sub.substring(0, idx); + expModel = baseDir + "/" + sub + "/" + TFGraphTestAllSameDiff.MODEL_FILENAME; + } + + + modelSet.add(name); + + TestCase tc = map.get(name); + if(tc == null){ + tc = new TestCase(name, null, null, null); + map.put(name, tc); + } + + if(s.endsWith("prediction.csv")){ + if(tc.outputs == null) + tc.outputs = new HashMap<>(); + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); +// String varName = sub.substring(idx+1).replaceAll("____", "/"); + varName = varName.substring(0, varName.length() - "prediction.csv".length() - 1); + tc.outputs.put(varName, s); + } else if(s.endsWith("placeholder.shape")){ + if(tc.inputs == null) + tc.inputs = new HashMap<>(); +// String varName = sub.substring(idx+1).replaceAll("____", "/"); + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); + varName = varName.substring(0, varName.length() - "placeholder.shape".length() - 1); + tc.inputs.put(varName, s); + } else if(s.endsWith("/dtypes")){ + File f = Resources.asFile(s); + List lines = FileUtils.readLines(f, StandardCharsets.UTF_8); + tc.datatypes = new HashMap<>(); + for(String line : lines){ + String[] split = line.split(" "); + Preconditions.checkState(split.length == 2, "Expected 2 entries in dtypes file, got %s", split.length); + String key = split[0].replaceAll("____", "/"); + DataType value = ArrayOptionsHelper.dataType(split[1]); + + // adding zero output duplicate (if it doesn't exist) + if (key.endsWith(".0")) { + val nkey = key.replaceAll("\\.0$",""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); + } + } else if (key.endsWith(":0")) { + val nkey = key.replaceAll(":0$",""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); + } + } + + tc.datatypes.put(line, null); + } + } +// System.out.println(sub); + } + } + long end2 = System.currentTimeMillis(); + + System.out.println("List duration: " + (end-start)); + System.out.println("Process duration: " + (end2-start2)); + return map; + } + + private static long parseLong(String line){ + line = line.trim(); //Handle whitespace + if(line.matches("-?\\d+\\.0+")){ + //Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000" + return Long.parseLong(line.substring(0, line.indexOf('.'))); + } else { + return Long.parseLong(line); + } + } + + private static double parseDouble(String line){ + line = line.trim(); //Handle whitespace - some lines are like " -inf" + if("nan".equalsIgnoreCase(line)){ + return Double.NaN; + } else if("inf".equalsIgnoreCase(line)) { + return Double.POSITIVE_INFINITY; + } else if("-inf".equalsIgnoreCase(line)){ + return Double.NEGATIVE_INFINITY; + } else { + return Double.parseDouble(line); + } + } + + private static boolean parseBoolean(String line){ + line = line.trim(); + if(line.matches("1(\\.0*)?")){ //Booleans are ocassionally represented like 1.000000 or 0.000000 + return true; + } else if(line.matches("0(\\.0*)?")){ + return false; + } + return Boolean.parseBoolean(line); + } + + public static INDArray loadCsv(String path, TestCase tc) throws IOException { + + DataType type = tc.datatypes.get(path); + + String shapeFile = path.substring(0, path.length()-4) + ".shape"; + List shapeLines = FileUtils.readLines(Resources.asFile(shapeFile), StandardCharsets.UTF_8); + List filteredShape = new ArrayList<>(shapeLines.size()); + for(String s : shapeLines){ + String trimmed = s.trim(); + if(!trimmed.isEmpty()){ + filteredShape.add(trimmed); + } + } + + if(type == null){ + log.warn("DATATYPE NOT AVAILABLE FOR: {} - {}", tc.modelName, path); + //Soon: this will be an exception + type = DataType.FLOAT; + } + + INDArray varValue = null; + if(filteredShape.size() == 0){ + //Scalar + String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); //IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); + switch (type){ + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + varValue = Nd4j.scalar(type, parseDouble(content)); + break; + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case UINT32: + case UINT64: + varValue = Nd4j.scalar(type, parseLong(content)); + break; + case BOOL: + varValue = Nd4j.scalar(parseBoolean(content)); + break; + case UTF8: + varValue = Nd4j.scalar(content); + break; + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); + } + } else { + int[] varShape = new int[filteredShape.size()]; + for( int j=0; j p = resources.get(i); +// boolean isRef = p.getSecond().isFile() && !p.getSecond().exists(); +// +// InputStream stream; +// if(isRef){ +// //Slight hack for loading strumpf reference files +// File r = new StrumpfResolver().localCacheRoot(); +// String path = p.getSecond().getFile() + StrumpfResolver.REF; +// File f = ResourceFile.fromFile(path).localFile(r); +// stream = new BufferedInputStream(new FileInputStream(f)); +// } else { +// stream = new BufferedInputStream(resources.get(i).getSecond().getInputStream()); +// } +// +// try(InputStream is = stream){ +// content = String.join("\n", IOUtils.readLines(is, StandardCharsets.UTF_8)); +// } + String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); + + if (content.isEmpty()) { + //Should be zeros in shape + boolean foundZero = false; + for( int s : varShape){ + foundZero |= (s == 0); + } + if(foundZero){ + varValue = Nd4j.create(type, ArrayUtil.toLongArray(varShape)); + } else { + throw new IllegalStateException("Empty data but non-empty shape: " + shapeFile); + } + } else { + if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of [] + varShape = new int[0]; + + String[] cLines = content.split("\n"); + switch (type){ + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + double[] dArr = new double[cLines.length]; + int x=0; + while(x < dArr.length){ + dArr[x] = parseDouble(cLines[x]); + x++; + } + varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape); + break; + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case UINT16: + case UINT32: + case UINT64: + long[] lArr = new long[cLines.length]; + int y=0; + while(y < lArr.length){ + lArr[y] = parseLong(cLines[y]); + y++; + } + varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape); + break; + case BOOL: + boolean[] bArr = new boolean[cLines.length]; + int z=0; + while(z < bArr.length){ + bArr[z] = parseBoolean(cLines[z]); + z++; + } + varValue = Nd4j.createFromArray(bArr).reshape('c', varShape); + break; + case UTF8: + varValue = Nd4j.create(cLines).reshape('c', varShape); + break; + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); + } + } + } catch (NumberFormatException e) { + log.warn("Error parsing number", e); +// continue; + } + } + + return varValue; + } + +} diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java index bafe94094d41..a71fe3939eec 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java @@ -61,11 +61,14 @@ public static List listClassPathFiles(String path, boolean recursive, bo private static List listClassPathFilesHelper(String path, boolean recursive, boolean includeDirectories, String... extensions) throws IOException { ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(path).getClassLoader()); + if(!path.endsWith("/")) + path = path + "/"; + StringBuilder sbPattern = new StringBuilder("classpath*:" + path); if (recursive) { - sbPattern.append("/**/*"); + sbPattern.append("**/*"); } else { - sbPattern.append("/*"); + sbPattern.append("*"); } //Normalize extensions so they are all like ".csv" etc - with leading "." From e562542ee8796c117c640fc8fa38e44515e9b3d6 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 16:19:16 +1000 Subject: [PATCH 2/8] Refactoring for TF import tests Signed-off-by: Alex Black --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 2 +- .../TFGraphs/TFGraphTestAllHelper.java | 6 +- .../TFGraphs/TFGraphTestAllSameDiff.java | 46 +--- .../imports/TFGraphs/TFGraphTestList.java | 36 +-- .../nd4j/imports/TFGraphs/TFGraphUtil.java | 239 +++++++++++------- .../org/nd4j/common/tests/ResourceUtils.java | 3 + 6 files changed, 177 insertions(+), 155 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 835a2f4cb8c9..05225173430d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -5521,7 +5521,7 @@ public boolean isS() { public INDArray castTo(DataType dataType) { if(dataType == dataType()) //No-op if correct datatype return this; - if(isEmpty()){ + if(isEmpty() && rank() == 0){ return Nd4j.empty(dataType); } val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 1cc3baa132ca..b8a87b00c2c5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -183,7 +183,7 @@ protected static void checkOnlyOutput(Map inputs, Map inputs, Map data() throws Exception { -// long start = System.currentTimeMillis(); -// val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); -// -// // if this variable isn't set - we're using dl4j-tests-resources -// if (localPath == null) { -// File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); -// List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); -// long end = System.currentTimeMillis(); -// System.out.println("TIME TO GET PARAMETERS: " + (end-start)); -// return params; -// } else { -// File baseDir = new File(localPath); -// List l = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); -// long end = System.currentTimeMillis(); -// System.out.println("TIME TO GET PARAMETERS: " + (end-start)); -// return l; -// } - - Map m = TFGraphUtil.getTestCases(BASE_DIR); + Map m = TFGraphUtil.getTestCases(BASE_DIR, false); List l = new ArrayList<>(m.keySet()); Collections.sort(l); @@ -182,13 +164,6 @@ public static Collection data() throws Exception { return out; } -// public TFGraphTestAllSameDiff(Map inputs, Map predictions, String modelName, File localTestDir) { -// this.inputs = inputs; -// this.predictions = predictions; -// this.modelName = modelName; -// this.localTestDir = localTestDir; -// } - public TFGraphTestAllSameDiff(String name, TestCase tc){ this.modelName = name; this.testCase = tc; @@ -230,22 +205,9 @@ public void testOutputOnly() throws Exception { } } - Map inputs = new HashMap<>(); - Map predictions = new HashMap<>(); - - if(testCase.inputs != null){ - for(String s : testCase.inputs.keySet()){ - INDArray arr = TFGraphUtil.loadCsv(s, testCase); - inputs.put(s, arr); - } - } + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); - if(testCase.outputs != null){ - for(String s : testCase.outputs.keySet()){ - INDArray arr = TFGraphUtil.loadCsv(s, testCase); - predictions.put(s, arr); - } - } try { TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 2c7f1e8575cd..655861eb2e9e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -29,10 +29,7 @@ import java.io.File; import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; +import java.util.*; /** * TFGraphTestAll* will run all the checked in TF graphs and @@ -55,7 +52,7 @@ public class TFGraphTestList { public static final boolean printArraysDebugging = false; public static String[] modelNames = new String[]{ - "bitcast/from_int32_to_uint32" + "arg_max/rank2_dim1" }; @After @@ -78,27 +75,32 @@ public static void beforeClass(){ } private String modelName; + private TestCase testCase; - @Parameterized.Parameters - public static Collection data() { - List modelNamesParams = new ArrayList<>(); - for (int i = 0; i < modelNames.length; i++) { - Object[] currentParams = new String[]{modelNames[i]}; - modelNamesParams.add(currentParams); + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { + + List out = new ArrayList<>(modelNames.length); + for(int i=0; i inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); - Map predictions = TFGraphTestAllHelper.outputVars(modelName, MODEL_DIR, dir); + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java index de1d73c5521f..82b64cbf67b6 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -1,8 +1,10 @@ package org.nd4j.imports.TFGraphs; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.io.FileUtils; +import org.apache.commons.io.FilenameUtils; import org.nd4j.common.base.Preconditions; import org.nd4j.common.resources.Resources; import org.nd4j.common.tests.ResourceUtils; @@ -20,9 +22,21 @@ @Slf4j public class TFGraphUtil { - private TFGraphUtil(){ } + private TFGraphUtil() { + } + + public static TestCase getTestCase(String baseDir, String testName) throws Exception { + String newBase = FilenameUtils.concat(baseDir, testName + "/"); + Map cases = getTestCases(newBase, true); + Preconditions.checkState(cases.size() == 1, "Expected 1 test case, got %s", cases.size()); + return cases.get(cases.keySet().iterator().next()); + } + + public static Map getTestCases(String baseDir, boolean singleTest) throws Exception { - public static Map getTestCases(String baseDir) throws Exception { + baseDir = baseDir.replaceAll("\\\\", "/"); + if (!baseDir.endsWith("/")) + baseDir += "/"; long start = System.currentTimeMillis(); // String baseDir = "tf_graphs/examples/"; @@ -33,105 +47,114 @@ public static Map getTestCases(String baseDir) throws Exception Set modelSet = new HashSet<>(); - List modelFileNames = new ArrayList<>(); Map map = new HashMap<>(); long start2 = System.currentTimeMillis(); - for(String s : l){ - String sub = s.substring(baseDir.length()+1); -// int idx = sub.indexOf('/'); -// if(idx > 0){ -// String model -// } - - int idx = sub.lastIndexOf('/'); - - - if(idx > 0) { - String name = sub.substring(0, idx); - String modelDir = baseDir + sub.substring(0,idx+1); + for (String s : l) { + String sub = s.substring(baseDir.length()); + + int idx = singleTest ? 0 : sub.lastIndexOf('/'); + + boolean badTest = false; + String name = null; + String modelDir = null; + if (singleTest) { + name = ""; + modelDir = baseDir; + } else if (idx > 0) { + name = sub.substring(0, idx); + modelDir = baseDir + sub.substring(0, idx + 1); String expModel = modelDir + TFGraphTestAllSameDiff.MODEL_FILENAME; // while(!Resources.exists(expModel) && idx > 0){ - while(!listAsSet.contains(expModel) && idx > 0){ + while (!listAsSet.contains(expModel) && idx > 0) { //Due to a mixing of directories and variable names - we //For example we might have "X/frozen_model.pb" //And then also "X/something/or/other.csv //When this occurs - we should look up the path to determine which part is the model name // and which part is the variable name idx = sub.lastIndexOf('/', idx); - if(idx < 0){ + if (idx < 0) { System.out.println("***** BAD TEST DIRECTORY: " + s + " ******"); - continue; + badTest = true; + break; } sub = sub.substring(0, idx); - expModel = baseDir + "/" + sub + "/" + TFGraphTestAllSameDiff.MODEL_FILENAME; + expModel = baseDir + sub + "/" + TFGraphTestAllSameDiff.MODEL_FILENAME; + modelDir = baseDir + sub + "/"; + name = sub; } +// name = n; + } +// if(modelDir == null) +// continue; + if(badTest || modelDir == null) + continue; - modelSet.add(name); - TestCase tc = map.get(name); - if(tc == null){ - tc = new TestCase(name, null, null, null); - map.put(name, tc); - } + modelSet.add(name); + + TestCase tc = map.get(name); + if (tc == null) { + tc = new TestCase(name, null, null, null); + map.put(name, tc); + } - if(s.endsWith("prediction.csv")){ - if(tc.outputs == null) - tc.outputs = new HashMap<>(); - String varName = s.substring(modelDir.length()).replaceAll("____", "/"); + if (s.endsWith("prediction.csv")) { + if (tc.outputs == null) + tc.outputs = new HashMap<>(); + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); // String varName = sub.substring(idx+1).replaceAll("____", "/"); - varName = varName.substring(0, varName.length() - "prediction.csv".length() - 1); - tc.outputs.put(varName, s); - } else if(s.endsWith("placeholder.shape")){ - if(tc.inputs == null) - tc.inputs = new HashMap<>(); + varName = varName.substring(0, varName.length() - "prediction.csv".length() - 1); + tc.outputs.put(varName, s); + } else if (s.endsWith("placeholder.csv")) { + if (tc.inputs == null) + tc.inputs = new HashMap<>(); // String varName = sub.substring(idx+1).replaceAll("____", "/"); - String varName = s.substring(modelDir.length()).replaceAll("____", "/"); - varName = varName.substring(0, varName.length() - "placeholder.shape".length() - 1); - tc.inputs.put(varName, s); - } else if(s.endsWith("/dtypes")){ - File f = Resources.asFile(s); - List lines = FileUtils.readLines(f, StandardCharsets.UTF_8); - tc.datatypes = new HashMap<>(); - for(String line : lines){ - String[] split = line.split(" "); - Preconditions.checkState(split.length == 2, "Expected 2 entries in dtypes file, got %s", split.length); - String key = split[0].replaceAll("____", "/"); - DataType value = ArrayOptionsHelper.dataType(split[1]); - - // adding zero output duplicate (if it doesn't exist) - if (key.endsWith(".0")) { - val nkey = key.replaceAll("\\.0$",""); - if (!tc.datatypes.containsKey(nkey)) { - tc.datatypes.put(nkey, value); - } - } else if (key.endsWith(":0")) { - val nkey = key.replaceAll(":0$",""); - if (!tc.datatypes.containsKey(nkey)) { - tc.datatypes.put(nkey, value); - } + String varName = s.substring(modelDir.length()).replaceAll("____", "/"); + varName = varName.substring(0, varName.length() - "placeholder.csv".length() - 1); + tc.inputs.put(varName, s); + } else if (s.endsWith("/dtypes")) { + File f = Resources.asFile(s); + List lines = FileUtils.readLines(f, StandardCharsets.UTF_8); + tc.datatypes = new HashMap<>(); + for (String line : lines) { + String[] split = line.split(" "); + Preconditions.checkState(split.length == 2, "Expected 2 entries in dtypes file, got %s", split.length); + String key = split[0].replaceAll("____", "/"); + DataType value = ArrayOptionsHelper.dataType(split[1]); + + // adding zero output duplicate (if it doesn't exist) + if (key.endsWith(".0")) { + val nkey = key.replaceAll("\\.0$", ""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); + } + } else if (key.endsWith(":0")) { + val nkey = key.replaceAll(":0$", ""); + if (!tc.datatypes.containsKey(nkey)) { + tc.datatypes.put(nkey, value); } - - tc.datatypes.put(line, null); } + + tc.datatypes.put(line, null); } -// System.out.println(sub); } +// System.out.println(sub); } long end2 = System.currentTimeMillis(); - System.out.println("List duration: " + (end-start)); - System.out.println("Process duration: " + (end2-start2)); + System.out.println("List duration: " + (end - start)); + System.out.println("Process duration: " + (end2 - start2)); return map; } - private static long parseLong(String line){ + private static long parseLong(String line) { line = line.trim(); //Handle whitespace - if(line.matches("-?\\d+\\.0+")){ + if (line.matches("-?\\d+\\.0+")) { //Annoyingly, some integer data is stored with redundant/unnecessary zeros - like "-7.0000000" return Long.parseLong(line.substring(0, line.indexOf('.'))); } else { @@ -139,54 +162,60 @@ private static long parseLong(String line){ } } - private static double parseDouble(String line){ + private static double parseDouble(String line) { line = line.trim(); //Handle whitespace - some lines are like " -inf" - if("nan".equalsIgnoreCase(line)){ + if ("nan".equalsIgnoreCase(line)) { return Double.NaN; - } else if("inf".equalsIgnoreCase(line)) { + } else if ("inf".equalsIgnoreCase(line)) { return Double.POSITIVE_INFINITY; - } else if("-inf".equalsIgnoreCase(line)){ + } else if ("-inf".equalsIgnoreCase(line)) { return Double.NEGATIVE_INFINITY; } else { return Double.parseDouble(line); } } - private static boolean parseBoolean(String line){ + private static boolean parseBoolean(String line) { line = line.trim(); - if(line.matches("1(\\.0*)?")){ //Booleans are ocassionally represented like 1.000000 or 0.000000 + if (line.matches("1(\\.0*)?")) { //Booleans are ocassionally represented like 1.000000 or 0.000000 return true; - } else if(line.matches("0(\\.0*)?")){ + } else if (line.matches("0(\\.0*)?")) { return false; } return Boolean.parseBoolean(line); } - public static INDArray loadCsv(String path, TestCase tc) throws IOException { + public static INDArray loadCsv(String path, @NonNull TestCase tc) throws IOException { - DataType type = tc.datatypes.get(path); + DataType type; + if(tc.datatypes == null){ + log.warn("No datatype available for: {}", path); + type = DataType.FLOAT; + } else { + type = tc.datatypes.get(path); + } - String shapeFile = path.substring(0, path.length()-4) + ".shape"; + String shapeFile = path.substring(0, path.length() - 4) + ".shape"; List shapeLines = FileUtils.readLines(Resources.asFile(shapeFile), StandardCharsets.UTF_8); List filteredShape = new ArrayList<>(shapeLines.size()); - for(String s : shapeLines){ + for (String s : shapeLines) { String trimmed = s.trim(); - if(!trimmed.isEmpty()){ + if (!trimmed.isEmpty()) { filteredShape.add(trimmed); } } - if(type == null){ + if (type == null) { log.warn("DATATYPE NOT AVAILABLE FOR: {} - {}", tc.modelName, path); //Soon: this will be an exception type = DataType.FLOAT; } INDArray varValue = null; - if(filteredShape.size() == 0){ + if (filteredShape.size() == 0) { //Scalar String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); //IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); - switch (type){ + switch (type) { case DOUBLE: case FLOAT: case HALF: @@ -216,7 +245,7 @@ public static INDArray loadCsv(String path, TestCase tc) throws IOException { } } else { int[] varShape = new int[filteredShape.size()]; - for( int j=0; j loadInputs(TestCase testCase) throws IOException { + Map inputs = null; + if(testCase.inputs != null){ + inputs = new HashMap<>(); + for(String s : testCase.inputs.keySet()){ + String path = testCase.inputs.get(s); + INDArray arr = TFGraphUtil.loadCsv(path, testCase); + inputs.put(s, arr); + } + } + return inputs; + } + + public static Map loadPredictions(TestCase testCase) throws IOException { + Map predictions = null; + if(testCase.outputs != null){ + predictions = new HashMap<>(); + for(String s : testCase.outputs.keySet()){ + String path = testCase.outputs.get(s); + INDArray arr = TFGraphUtil.loadCsv(path, testCase); + predictions.put(s, arr); + } + } + return predictions; + } } diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java index a71fe3939eec..233e99595ac5 100644 --- a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/common/tests/ResourceUtils.java @@ -61,6 +61,9 @@ public static List listClassPathFiles(String path, boolean recursive, bo private static List listClassPathFilesHelper(String path, boolean recursive, boolean includeDirectories, String... extensions) throws IOException { ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(path).getClassLoader()); + if(path.contains("\\")) + path = path.replaceAll("\\\\", "/"); + if(!path.endsWith("/")) path = path + "/"; From afca40c27d5cf3762b3d454d96e563cbad035299 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 17:10:41 +1000 Subject: [PATCH 3/8] Fixes Signed-off-by: Alex Black --- .../imports/TFGraphs/TFGraphTestAllHelper.java | 4 +++- .../nd4j/imports/TFGraphs/TFGraphTestList.java | 2 +- .../org/nd4j/imports/TFGraphs/TFGraphUtil.java | 18 ++++++------------ .../org/nd4j/imports/TFGraphs/TestCase.java | 17 +++++++++++++++++ 4 files changed, 27 insertions(+), 14 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index b8a87b00c2c5..505cc195998e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -40,6 +40,7 @@ import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.resources.Resources; import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.listeners.ExecPrintListener; @@ -392,7 +393,8 @@ public static Pair> getGraphAfterExec(String base ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, Set requiredOutputs, boolean printArraysDebugging) throws IOException { log.info("RUNNING TEST {}...", modelName); - SameDiff graph = graphLoaderFunction.apply(new ClassPathResource(baseDir + "/" + modelName + "/" + modelFilename).getFile(), modelName); + File f = Resources.asFile(baseDir + "/" + modelName + "/" + modelFilename); + SameDiff graph = graphLoaderFunction.apply(f, modelName); if(listeners != null){ graph.setListeners(listeners); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index 655861eb2e9e..a6601418143c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -52,7 +52,7 @@ public class TFGraphTestList { public static final boolean printArraysDebugging = false; public static String[] modelNames = new String[]{ - "arg_max/rank2_dim1" + "emptyArrayTests/identity_n/rank1" }; @After diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java index 82b64cbf67b6..c66229e0c482 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -138,9 +138,9 @@ public static Map getTestCases(String baseDir, boolean singleT if (!tc.datatypes.containsKey(nkey)) { tc.datatypes.put(nkey, value); } + } else { + tc.datatypes.put(split[0], value); } - - tc.datatypes.put(line, null); } } // System.out.println(sub); @@ -185,14 +185,14 @@ private static boolean parseBoolean(String line) { return Boolean.parseBoolean(line); } - public static INDArray loadCsv(String path, @NonNull TestCase tc) throws IOException { + public static INDArray loadCsv(String path, String varName, @NonNull TestCase tc) throws IOException { DataType type; if(tc.datatypes == null){ log.warn("No datatype available for: {}", path); type = DataType.FLOAT; } else { - type = tc.datatypes.get(path); + type = tc.datatypes.get(varName); } String shapeFile = path.substring(0, path.length() - 4) + ".shape"; @@ -205,12 +205,6 @@ public static INDArray loadCsv(String path, @NonNull TestCase tc) throws IOExcep } } - if (type == null) { - log.warn("DATATYPE NOT AVAILABLE FOR: {} - {}", tc.modelName, path); - //Soon: this will be an exception - type = DataType.FLOAT; - } - INDArray varValue = null; if (filteredShape.size() == 0) { //Scalar @@ -349,7 +343,7 @@ public static Map loadInputs(TestCase testCase) throws IOExcept inputs = new HashMap<>(); for(String s : testCase.inputs.keySet()){ String path = testCase.inputs.get(s); - INDArray arr = TFGraphUtil.loadCsv(path, testCase); + INDArray arr = TFGraphUtil.loadCsv(path, s, testCase); inputs.put(s, arr); } } @@ -362,7 +356,7 @@ public static Map loadPredictions(TestCase testCase) throws IOE predictions = new HashMap<>(); for(String s : testCase.outputs.keySet()){ String path = testCase.outputs.get(s); - INDArray arr = TFGraphUtil.loadCsv(path, testCase); + INDArray arr = TFGraphUtil.loadCsv(path, s, testCase); predictions.put(s, arr); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java new file mode 100644 index 000000000000..275ad1d37a9b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java @@ -0,0 +1,17 @@ +package org.nd4j.imports.TFGraphs; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.nd4j.linalg.api.buffer.DataType; + +import java.util.Map; + +@AllArgsConstructor +@Data +public class TestCase { + public String modelName; +// public String dir; + public Map inputs; //Key: variable name, values: filename (.csv) + public Map outputs; + public Map datatypes; +} From 94350cf16518331e74b65ee3e3b95196e2ee44f5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 18:01:06 +1000 Subject: [PATCH 4/8] Add test runner for TF2 import tests - TF2ImportTestsSameDiff Signed-off-by: Alex Black --- .../TFGraphs/TF2ImportTestsSameDiff.java | 156 ++++++++++++++++++ .../TFGraphs/TFGraphTestAllSameDiff.java | 3 - 2 files changed, 156 insertions(+), 3 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java new file mode 100644 index 000000000000..368fa24681f5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java @@ -0,0 +1,156 @@ +/* ****************************************************************************** + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.imports.TFGraphs; + +import lombok.extern.slf4j.Slf4j; +import org.junit.*; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.OpValidationSuite; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.executioner.OpExecutioner; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.*; + +@Slf4j +@RunWith(Parameterized.class) +@Ignore //AB 2020/05/12 - Disabled until TF 2.x import test resources are available +public class TF2ImportTestsSameDiff { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests + + @Rule + public TestWatcher testWatcher = new TestWatcher() { + + @Override + protected void starting(Description description){ + log.info("TF2ImportTestsSameDiff: Starting parameterized test: " + description.getDisplayName()); + } + + //protected void failed(Throwable e, Description description) { + //protected void succeeded(Description description) { + }; + + private String modelName; + private TestCase testCase; + private String baseDir; + + private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; + public static final String[] BASE_DIRS = new String[]{"tf_graphs/examples2.1"}; //Add directories for any other TensorFlow versions here + public static final String MODEL_FILENAME = "frozen_model.pb"; + + public static final String[] IGNORE_REGEXES = new String[]{ + + }; + + /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have + all arrays printed during execution. + If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output + arrays will be printed during execution + */ + private final List debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*"); + + @BeforeClass + public static void beforeClass() { + Nd4j.setDataType(DataType.FLOAT); + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + } + + @Before + public void setup() { + Nd4j.setDataType(DataType.FLOAT); + Nd4j.getExecutioner().enableDebugMode(false); + Nd4j.getExecutioner().enableVerboseMode(false); + } + + @Parameterized.Parameters(name="{3}") + public static Collection data() throws Exception { + List out = new ArrayList<>(); + + for(String dir : BASE_DIRS) { + String version = dir.replaceAll("tf_graphs/examples", "tf"); + Map m = TFGraphUtil.getTestCases(dir, false); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + for (String s : l) { + out.add(new Object[]{s, m.get(s), dir, version + "/" + s}); + } + } + return out; + } + + public TF2ImportTestsSameDiff(String name, TestCase tc, String baseDir, String displayName){ + this.modelName = name; + this.testCase = tc; + this.baseDir = baseDir; + } + + @Test + public void testOutputOnly() throws Exception { + if(TFGraphTestZooModels.isPPC()){ + /* + Ugly hack to temporarily disable tests on PPC only on CI + Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657 + These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions + */ + + log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657"); + OpValidationSuite.ignoreFailing(); + } + + + Nd4j.create(1); + + for(String s : IGNORE_REGEXES){ + if(modelName.matches(s)){ + log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); + OpValidationSuite.ignoreFailing(); + } + } + Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); + Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); + Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); + + boolean verboseDebugMode = false; + if(debugModeRegexes != null){ + for(String regex : debugModeRegexes){ + if(modelName.matches(regex)){ + verboseDebugMode = true; + break; + } + } + } + + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + + + try { + TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, baseDir, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode); + //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir); + } catch (Throwable t){ + log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t); + throw t; + } + //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 778e324aa61c..919bbe6fec91 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -54,10 +54,7 @@ protected void starting(Description description){ //protected void succeeded(Description description) { }; -// private Map inputs; -// private Map predictions; private String modelName; - private File localTestDir; private TestCase testCase; private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; From 3a92ec2ae3cd5b3bf19f70defcc095a00a2c18cc Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 18:15:59 +1000 Subject: [PATCH 5/8] Copyright header and clean up Signed-off-by: Alex Black --- .../nd4j/imports/TFGraphs/TFGraphUtil.java | 35 ++++++++----------- .../org/nd4j/imports/TFGraphs/TestCase.java | 16 ++++++++- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java index c66229e0c482..790118ae4966 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.imports.TFGraphs; import lombok.NonNull; @@ -107,13 +122,11 @@ public static Map getTestCases(String baseDir, boolean singleT if (tc.outputs == null) tc.outputs = new HashMap<>(); String varName = s.substring(modelDir.length()).replaceAll("____", "/"); -// String varName = sub.substring(idx+1).replaceAll("____", "/"); varName = varName.substring(0, varName.length() - "prediction.csv".length() - 1); tc.outputs.put(varName, s); } else if (s.endsWith("placeholder.csv")) { if (tc.inputs == null) tc.inputs = new HashMap<>(); -// String varName = sub.substring(idx+1).replaceAll("____", "/"); String varName = s.substring(modelDir.length()).replaceAll("____", "/"); varName = varName.substring(0, varName.length() - "placeholder.csv".length() - 1); tc.inputs.put(varName, s); @@ -244,24 +257,6 @@ public static INDArray loadCsv(String path, String varName, @NonNull TestCase tc } try { -// String content; -// Pair p = resources.get(i); -// boolean isRef = p.getSecond().isFile() && !p.getSecond().exists(); -// -// InputStream stream; -// if(isRef){ -// //Slight hack for loading strumpf reference files -// File r = new StrumpfResolver().localCacheRoot(); -// String path = p.getSecond().getFile() + StrumpfResolver.REF; -// File f = ResourceFile.fromFile(path).localFile(r); -// stream = new BufferedInputStream(new FileInputStream(f)); -// } else { -// stream = new BufferedInputStream(resources.get(i).getSecond().getInputStream()); -// } -// -// try(InputStream is = stream){ -// content = String.join("\n", IOUtils.readLines(is, StandardCharsets.UTF_8)); -// } String content = FileUtils.readFileToString(Resources.asFile(path), StandardCharsets.UTF_8); if (content.isEmpty()) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java index 275ad1d37a9b..d7b23660a56f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TestCase.java @@ -1,3 +1,18 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ package org.nd4j.imports.TFGraphs; import lombok.AllArgsConstructor; @@ -10,7 +25,6 @@ @Data public class TestCase { public String modelName; -// public String dir; public Map inputs; //Key: variable name, values: filename (.csv) public Map outputs; public Map datatypes; From 35a3def2986639bf23c6c8fca3734c4bea178ccd Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 18:37:49 +1000 Subject: [PATCH 6/8] Cleanup and copyright headers Signed-off-by: Alex Black --- .../org/nd4j/AssertTestsExtendBaseClass.java | 4 +- .../TFGraphs/TF2ImportTestsSameDiff.java | 2 +- .../TFGraphs/TFGraphTestAllHelper.java | 472 +----------------- .../TFGraphs/TFGraphTestAllLibnd4j.java | 188 ------- .../TFGraphs/TFGraphTestAllSameDiff.java | 2 +- .../imports/TFGraphs/TFGraphTestList.java | 10 +- .../TFGraphs/TFGraphTestZooModels.java | 34 +- .../nd4j/imports/TFGraphs/TFGraphUtil.java | 26 +- 8 files changed, 36 insertions(+), 702 deletions(-) delete mode 100644 nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java index 6414bec2f3f0..defc5bdee102 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/AssertTestsExtendBaseClass.java @@ -18,7 +18,7 @@ import lombok.extern.slf4j.Slf4j; import org.nd4j.common.tests.AbstractAssertTestsClass; import org.nd4j.common.tests.BaseND4JTest; -import org.nd4j.imports.TFGraphs.TFGraphTestAllLibnd4j; +import org.nd4j.imports.TFGraphs.TF2ImportTestsSameDiff; import org.nd4j.imports.TFGraphs.TFGraphTestAllSameDiff; import org.nd4j.imports.TFGraphs.TFGraphTestList; import org.nd4j.imports.TFGraphs.TFGraphTestZooModels; @@ -41,7 +41,7 @@ protected Set> getExclusions() { //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) return new HashSet<>(Arrays.asList( TFGraphTestAllSameDiff.class, - TFGraphTestAllLibnd4j.class, + TF2ImportTestsSameDiff.class, TFGraphTestList.class, TFGraphTestZooModels.class, ImportModelDebugger.class //Run manually only, otherwise ignored diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java index 368fa24681f5..bf6628ed2eb3 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java @@ -87,7 +87,7 @@ public static Collection data() throws Exception { for(String dir : BASE_DIRS) { String version = dir.replaceAll("tf_graphs/examples", "tf"); - Map m = TFGraphUtil.getTestCases(dir, false); + Map m = TFGraphUtil.getTestCases(dir, false, MODEL_FILENAME); List l = new ArrayList<>(m.keySet()); Collections.sort(l); for (String s : l) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java index 505cc195998e..bf7eb8accd80 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllHelper.java @@ -19,29 +19,25 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.apache.commons.io.FilenameUtils; -import org.apache.commons.io.IOUtils; -import org.apache.commons.lang3.math.NumberUtils; import org.junit.After; import org.junit.Before; import org.junit.BeforeClass; import org.nd4j.autodiff.execution.NativeGraphExecutioner; - import org.nd4j.autodiff.execution.conf.ExecutionMode; import org.nd4j.autodiff.execution.conf.ExecutorConfiguration; import org.nd4j.autodiff.execution.conf.OutputMode; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.internal.InferenceSession; -import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.autodiff.samediff.internal.memory.ArrayCloseMemoryMgr; import org.nd4j.autodiff.samediff.internal.memory.CloseValidationMemoryMgr; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.function.BiFunction; +import org.nd4j.common.io.ClassPathResource; +import org.nd4j.common.primitives.Pair; import org.nd4j.common.resources.Resources; -import org.nd4j.imports.TFGraphs.listener.OpExecOrderListener; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.imports.listeners.ExecPrintListener; import org.nd4j.linalg.api.buffer.DataType; @@ -49,33 +45,22 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.function.BiFunction; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.common.io.ClassPathResource; import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.string.NDArrayStrings; -import org.nd4j.common.util.ArrayUtil; import org.nd4j.nativeblas.NativeOpsHolder; -import org.nd4j.common.resources.strumpf.ResourceFile; -import org.nd4j.common.resources.strumpf.StrumpfResolver; -import org.springframework.core.io.FileSystemResource; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternResolver; import java.io.*; -import java.net.URI; import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; import java.util.*; import java.util.regex.Pattern; import static org.junit.Assert.*; -import static org.nd4j.imports.TFGraphs.TFGraphsSkipNodes.skipNode; /** * Created by susaneraly on 11/6/17. @@ -125,20 +110,6 @@ public void tearDown() { .outputMode(OutputMode.VARIABLE_SPACE) .build(); - protected static List fetchTestParams(String baseDir, String modelFileName, ExecuteWith executeWith, File localTestDir) throws IOException { - String[] modelNames = modelDirNames(baseDir, executeWith, modelFileName); - List modelParams = new ArrayList<>(); - for (int i = 0; i < modelNames.length; i++) { - Object[] currentParams = new Object[4]; - currentParams[0] = inputVars(modelNames[i], baseDir, localTestDir); //input variable map - could be null - currentParams[1] = outputVars(modelNames[i], baseDir, localTestDir); //saved off predictions - currentParams[2] = modelNames[i]; - currentParams[3] = localTestDir; - modelParams.add(currentParams); - } - return modelParams; - } - protected static void checkOnlyOutput(Map inputs, Map predictions, String modelName, String baseDir, String modelFilename, ExecuteWith execType, BiFunction loader, Double maxRelErrorOverride, Double minAbsErrorOverride, boolean printArraysDebugging) throws IOException { @@ -298,97 +269,6 @@ protected static void checkOnlyOutput(Map inputs, Map inputs, String modelName, String baseDir, String modelFileName, - ExecuteWith execType, File localTestDir, boolean printArraysDebugging) throws IOException { - checkIntermediate(inputs, modelName, baseDir, modelFileName, execType, LOADER, null, null, localTestDir, printArraysDebugging); - } - - public static void checkIntermediate(Map inputs, String modelName, String baseDir, String modelFileName, - ExecuteWith execType, BiFunction loader, - Double maxRelErrorOverride, Double minAbsErrorOverride, File localTestDir, boolean printArraysDebugging) throws IOException { - Preconditions.checkArgument((maxRelErrorOverride == null) == (minAbsErrorOverride == null), "Both maxRelErrorOverride and minAbsErrorOverride" + - " must be null or both must be provided"); - Nd4j.EPS_THRESHOLD = 1e-3; - OpExecOrderListener listener = new OpExecOrderListener(); //Used to collect exec order - Pair> p = getGraphAfterExec(baseDir, modelFileName, modelName, inputs, execType, loader, Collections.singletonList(listener), null, printArraysDebugging); - SameDiff graph = p.getFirst(); - Map sdPredictions = p.getSecond(); - - //Collect coverage info about ops - OpValidation.collectTensorflowImportCoverage(graph); - - if (!execType.equals(ExecuteWith.JUST_PRINT)) { - int count = 0; - //Evaluate the nodes in their execution order - this is useful for debugging (as we want the *first* failure - // to be detected before later failures) - List varNames = new ArrayList<>(); - Map fns = graph.getOps(); - List execOrder = listener.getOpNamesList(); - for(String opName : execOrder){ - String[] outputs = graph.getOutputsForOp(fns.get(opName).getOp()); - Collections.addAll(varNames, outputs); - } - - for (String varName : varNames) { - if (!inputs.containsKey(varName)) { //avoiding placeholders - INDArray tfValue = intermediateVars(modelName, baseDir, varName, localTestDir); - if (tfValue == null) { - continue; - } - log.info("Starting check: variable {}", varName); - if (skipNode(modelName, varName)) { - log.info("\n\tFORCING no check on " + varName); - } else { - assertArrayEquals("Shape not equal on node " + varName, tfValue.shape(), graph.getVariable(varName).getShape()); - INDArray sdVal = sdPredictions.get(varName); - if(maxRelErrorOverride != null){ - INDArray diff = Transforms.abs(tfValue.sub(sdVal), false); - INDArray absErrorMask = diff.gte(minAbsErrorOverride); //value 1 if x[i] > minAbsError; value 0 otherwise. Used to get rid of 1e-30 vs. 1e-29 type failures - INDArray sumAbs = Transforms.abs(tfValue, true).addi(Transforms.abs(sdVal, true)); - BooleanIndexing.replaceWhere(sumAbs, 1.0, Conditions.equals(0.0)); //Can only get 0.0 if both are zeros - need to avoid 0/0=NaN - INDArray relError = diff.divi(sumAbs); - relError.muli(absErrorMask); - - int countExceeds = Nd4j.getExecutioner().exec(new MatchCondition(relError, Conditions.greaterThan(maxRelErrorOverride))).getInt(0); - - double maxRE = -1; - //Mainly used for analysis in debugger: - DifferentialFunction op = null; - String[] opInputs = null; - if(countExceeds > 0){ - maxRE = relError.maxNumber().doubleValue(); - //Find the op that this variable is produced by - op = graph.getVariableOutputOp(varName); - opInputs = graph.getInputsForOp(op); - } - - - assertEquals( varName + ": " + countExceeds + " values exceed maxRelError=" + maxRelErrorOverride - + " with minAbsError=" + minAbsErrorOverride + "; largest observed relError=" + maxRE, 0, countExceeds); - } else { -// assertEquals("Value not equal on node " + varName, tfValue, sdVal); - if(tfValue.equals(sdVal)){ - System.out.println("Pass: " + varName); - } else { - System.out.println("FAIL: " + varName); - System.out.println("TF:\n" + tfValue); - System.out.println("SD:\n" + sdVal); - } - - } - log.info("Values and shapes equal for {}", varName); - count++; - } - - } - } - - assertTrue("No intermediate variables were checked", count > 0); - } - - Nd4j.EPS_THRESHOLD = 1e-5; - } - public static Pair> getGraphAfterExec(String baseDir, String modelFilename, String modelName, Map inputs, ExecuteWith executeWith, BiFunction graphLoaderFunction, List listeners, Set requiredOutputs, boolean printArraysDebugging) throws IOException { @@ -454,353 +334,7 @@ private static String[] modelDirNames(String base_dir, ExecuteWith executeWith, return exampleNames; } - protected static Map inputVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.placeholder", true, localTestDir); - } - - - protected static Map outputVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.prediction", true, localTestDir); - } - - protected static Map inbetweenVars(String modelName, String base_dir, File localTestDir) throws IOException { - return readVars(modelName, base_dir, "**.prediction_inbw", true, localTestDir); - } - - - //return readVars(modelName, base_dir, "**.prediction_inbw", true); - - /** - * Possible for a single node to give multiple outputs - * - * How is a node that has a list of outputs like in the case of "node_multiple_out" work - * Below is hardcoded for a single node - */ - protected static INDArray intermediateVars(String modelName, String base_dir, String varName, File localTestDir) throws IOException { - //convert varName to convention used in naming files - // "/" replaced by "____"; followed by a digit indicating the output number followed by prediction_inbw.(shape|csv) - if (varName.contains(":")) { - varName = varName.replace(':', '.'); - } else { - varName = varName + ".0"; - } - String name = varName.replaceAll("/", "____") + ".prediction_inbw"; - Map nodeSepOutput = readVars(modelName, base_dir, name, true, localTestDir); - - boolean importNameWorkaround = false; - if(nodeSepOutput.isEmpty()){ - //Edge case: intermediates were generated with help of import_graph_def method, which by default adds "import/" to names - // for some reason. https://www.tensorflow.org/api_docs/python/tf/graph_util/import_graph_def - //So many of earlier intermediate nodes test data were generated with filenames like "import___X..." instead of "X..." - name = "import____" + name; - nodeSepOutput = readVars(modelName, base_dir, name, true, localTestDir); - importNameWorkaround = true; - } - - //required check for pattern matching as there are scopes and "*" above is a greedy match - Set removeList = confirmPatternMatch(nodeSepOutput.keySet(), importNameWorkaround ? "import/" + varName : varName); - for (String toRemove : removeList) { - nodeSepOutput.remove(toRemove); - } - if(importNameWorkaround){ - return nodeSepOutput.get("import/" + varName); //this *should* return a list of the indarrays for each node - } else { - return nodeSepOutput.get(varName); //this *should* return a list of the indarrays for each node - } - } - - public static Set confirmPatternMatch(Set setOfNames, String varName) { - Set removeList = new HashSet<>(); - for (String name : setOfNames) { - if (name.equals(varName)) continue; - String[] splitByPeriod = name.split("\\."); - //not a number - maybe another variable deeper in the same scope - if (!NumberUtils.isNumber(splitByPeriod[splitByPeriod.length - 1])) { - removeList.add(name); - } else if (!String.join(".", Arrays.copyOfRange(splitByPeriod, 0, splitByPeriod.length - 1)).equals(varName)) { - removeList.add(name); - } - } - return removeList; - } - - - protected static Map readVars(String modelName, String base_dir, String pattern, boolean recursive, File localTestDir) throws IOException { - Map varMap = new HashMap<>(); - String modelDir = base_dir + "/" + modelName; - - // key is variable name, value is data type - val dtypes = new HashMap(); - - List> resources = new ArrayList<>(); - if(recursive){ - String nameRegex = pattern.replace("**.",".*\\.") + "\\.shape"; -// File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString() + "/" + modelName); -// baseDir.mkdirs(); -// baseDir.deleteOnExit(); -// new ClassPathResource(modelDir).copyDirectory(baseDir); - - // checking out, if local folder declared - String localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - if(localPath != null && (!localPath.contains("src/main/resources") && !localPath.contains("src\\main\\resources"))){ - localPath = FilenameUtils.concat(localPath, "src/main/resources"); - } - - // baseDir will differ, depending on run mode - File baseDir = localPath == null ? new File(localTestDir, "extracted/" + modelName) : new File(localPath, base_dir + "/" + modelName); - String[] arr = baseDir.list(); - - if(!baseDir.exists() || arr == null || arr.length == 0){ - // we're skipping extraction if we're using local copy of dl4j-tests-resources - if (localPath == null) { - baseDir.mkdirs(); - baseDir.deleteOnExit(); - String md = modelDir; - if(!md.endsWith("/") && !md.endsWith("\\")){ - md = md + "/"; - } - new ClassPathResource(md).copyDirectory(baseDir); - } else{ - throw new IllegalStateException("local directory declared but could not find files: " + baseDir.getAbsolutePath()); - } - - } - - LinkedList queue = new LinkedList<>(); - queue.add(baseDir); - - while(!queue.isEmpty()){ - File subdir = queue.remove(); - File[] files = subdir.listFiles(); - if (files != null) { - for (File f : files) { - if (f.isDirectory()) { - queue.add(f); - } else { - String filename = f.getName(); - if(filename.matches(nameRegex)){ - File csvFile = new File(f.getAbsolutePath().replace(".shape",".csv")); - resources.add(new Pair<>(new FileSystemResource(f), new FileSystemResource(csvFile))); - } else if (filename.equals("dtypes")) { - List stringList; - - try (val is = new BufferedInputStream(new FileInputStream(f))) { - stringList = IOUtils.readLines(is, StandardCharsets.UTF_8); - - for (val s:stringList) { - val split = s.split("\\ "); - - val okey = split[0].replaceAll("____", "/"); - // adopt / in names - val key = modelDir + "/" + okey; - - // parse type directly - DataType value = ArrayOptionsHelper.dataType(split[1]); - - // adding key directly - //if (dtypes.containsKey(key)) - // throw new ND4JIllegalStateException("Specified key already exist: [" + key + "]"); - //else - - dtypes.put(key, value); - - // adding zero output duplicate (if it doesn't exist) - if (key.endsWith(".0")) { - val nkey = key.replaceAll("\\.0$",""); - if (!dtypes.containsKey(nkey)) { - dtypes.put(nkey, value); - } - } else if (key.endsWith(":0")) { - val nkey = key.replaceAll(":0$",""); - if (!dtypes.containsKey(nkey)) { - dtypes.put(nkey, value); - } - } - } - } catch (FileNotFoundException e) { - stringList = new ArrayList<>(); - } - } - } - } - } - } - } else { - ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(modelDir).getClassLoader()); - Resource[] r = resolver.getResources("classpath*:" + modelDir + "/" + pattern + ".shape"); - for(Resource res : r){ - String fileName = res.getFilename(); - String varPath = modelDir + "/" + fileName; - Resource r2 = new org.springframework.core.io.ClassPathResource(varPath.replace(".shape", ".csv")); - resources.add(new Pair<>(res, r2)); - } - - } - -// Preconditions.checkState(!dtypes.isEmpty(), "No datatypes file was found"); - val dtype = Nd4j.dataType(); - for (int i = 0; i < resources.size(); i++) { - URI u = resources.get(i).getFirst().getURI(); - String varName = u.toString(); - int idx = varName.indexOf(modelName); - varName = varName.substring(idx + modelName.length()+1); //+1 for "/" - varName = varName.replaceAll("____","/"); - varName = varName.replaceAll(".placeholder.shape",""); - varName = varName.replaceAll(".prediction.shape",""); - varName = varName.replaceAll(".prediction_inbw.shape",""); - - DataType type = dtypes.get(modelDir + "/" + varName); - - List lines; //= FileUtils.readLines(new ClassPathResource(varPath).getFile(), Charset.forName("UTF-8")); - try(InputStream is = new BufferedInputStream(resources.get(i).getFirst().getInputStream())){ - lines = IOUtils.readLines(is, StandardCharsets.UTF_8); - } - List filtered = new ArrayList<>(lines.size()); - for(String s : lines){ - String trimmed = s.trim(); - if(!trimmed.isEmpty()){ - filtered.add(trimmed); - } - } - - if(type == null){ - log.warn("DATATYPE NOT AVAILABLE FOR: {} - {}", modelName, varName); - //Soon: this will be an exception - type = DataType.FLOAT; - } - - INDArray varValue; - if(filtered.size() == 0){ - //Scalar - String content = IOUtils.toString(resources.get(i).getSecond().getInputStream(), StandardCharsets.UTF_8); - switch (type){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - varValue = Nd4j.scalar(type, parseDouble(content)); - break; - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case UINT16: - case UINT32: - case UINT64: - varValue = Nd4j.scalar(type, parseLong(content)); - break; - case BOOL: - varValue = Nd4j.scalar(parseBoolean(content)); - break; - case UTF8: - varValue = Nd4j.scalar(content); - break; - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); - } - } else { - int[] varShape = new int[filtered.size()]; - for( int j=0; j p = resources.get(i); - boolean isRef = p.getSecond().isFile() && !p.getSecond().exists(); - - InputStream stream; - if(isRef){ - //Slight hack for loading strumpf reference files - File r = new StrumpfResolver().localCacheRoot(); - String path = p.getSecond().getFile() + StrumpfResolver.REF; - File f = ResourceFile.fromFile(path).localFile(r); - stream = new BufferedInputStream(new FileInputStream(f)); - } else { - stream = new BufferedInputStream(resources.get(i).getSecond().getInputStream()); - } - - try(InputStream is = stream){ - content = String.join("\n", IOUtils.readLines(is, StandardCharsets.UTF_8)); - } - - if (content.isEmpty()) { - //Should be zeros in shape - boolean foundZero = false; - for( int s : varShape){ - foundZero |= (s == 0); - } - if(foundZero){ - varValue = Nd4j.create(type, ArrayUtil.toLongArray(varShape)); - } else { - throw new IllegalStateException("Empty data but non-empty shape: " + resources.get(i).getSecond()); - } - } else { - if(varShape.length == 1 && varShape[0] == 0) //Annoyingly, some scalars have shape [0] instead of [] - varShape = new int[0]; - - String[] cLines = content.split("\n"); - switch (type){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - double[] dArr = new double[cLines.length]; - int x=0; - while(x < dArr.length){ - dArr[x] = parseDouble(cLines[x]); - x++; - } - varValue = Nd4j.createFromArray(dArr).castTo(type).reshape('c', varShape); - break; - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case UINT16: - case UINT32: - case UINT64: - long[] lArr = new long[cLines.length]; - int y=0; - while(y < lArr.length){ - lArr[y] = parseLong(cLines[y]); - y++; - } - varValue = Nd4j.createFromArray(lArr).castTo(type).reshape('c', varShape); - break; - case BOOL: - boolean[] bArr = new boolean[cLines.length]; - int z=0; - while(z < bArr.length){ - bArr[z] = parseBoolean(cLines[z]); - z++; - } - varValue = Nd4j.createFromArray(bArr).reshape('c', varShape); - break; - case UTF8: - varValue = Nd4j.create(cLines).reshape('c', varShape); - break; - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Unknown / not implemented datatype: " + type); - } - } - } catch (NumberFormatException e) { - log.warn("Error parsing number", e); - continue; - } - } - - varMap.put(varName, varValue); - } - return varMap; - } private static long parseLong(String line){ line = line.trim(); //Handle whitespace diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java deleted file mode 100644 index 9b6b0f372846..000000000000 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllLibnd4j.java +++ /dev/null @@ -1,188 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.imports.TFGraphs; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.junit.*; -import org.junit.rules.TestWatcher; -import org.junit.runner.Description; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.OpValidationSuite; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; -import org.nd4j.nativeblas.NativeOpsHolder; - -import java.io.File; -import java.io.IOException; -import java.util.*; - -/** - * Created by susaneraly on 11/29/17. - */ -@RunWith(Parameterized.class) -@Slf4j -@Ignore("AB 2019/05/21 - JVM Crashes - Issue #7657") -public class TFGraphTestAllLibnd4j { //Note: Can't extend BaseNd4jTest here as we need no-arg constructor for parameterized tests - - @Rule - public TestWatcher testWatcher = new TestWatcher() { - - @Override - protected void starting(Description description){ - log.info("TFGraphTestAllLibnd4j: Starting parameterized test: " + description.getDisplayName()); - } - - //protected void failed(Throwable e, Description description) { - //protected void succeeded(Description description) { - }; - - private Map inputs; - private Map predictions; - private String modelName; - private File localTestDir; - - private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.LIBND4J; - private static final String BASE_DIR = "tf_graphs/examples"; - private static final String MODEL_FILENAME = "frozen_model.pb"; - - private static final String[] SKIP_FOR_LIBND4J_EXEC = new String[]{ - //Exceptions - need to look into: - "alpha_dropout/.*", - "layers_dropout/.*", - //"losses/.*", - - //These can't pass until this is fixed: https://github.com/deeplearning4j/deeplearning4j/issues/6465#issuecomment-424209155 - //i.e., reduction ops with newFormat/keepDims args - //"l2_normalize/.*", - //"norm_tests/.*", - "g_06", - - //JVM crashes - "simpleif.*", - "simple_cond.*", - - //2019/01/24 - Failing - "cond/cond_true", - "simplewhile_.*", - "simple_while", - "while1/.*", - "while2/a", - - //2019/01/24 - TensorArray support missing at libnd4j exec level?? - "tensor_array/.*", - - //2019/02/04 - Native execution exception: "Graph wasn't toposorted" - "primitive_gru_dynamic", - - //2019/02/08 - Native execution exception: "Graph wasn't toposorted". Note it's only the dynamic (while loop) RNNs - "rnn/basiclstmcell/dynamic.*", - "rnn/basicrnncell/dynamic.*", - "rnn/bidir_basic/dynamic.*", - "rnn/fused_adapt_basic/dynamic.*", - "rnn/grucell/dynamic.*", - "rnn/lstmcell/dynamic.*", - "rnn/srucell/dynamic.*", - - //2019/02/23 Passing for SameDiff exec, failing for libnd4j exec - "rnn/grublockcellv2/.*", - "rnn/lstmblockcell/.*", - "rnn/lstmblockfusedcell/.*", - }; - - @BeforeClass - public static void beforeClass() { - Nd4j.setDataType(DataType.FLOAT); - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); - } - - @Before - public void setup(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @After - public void tearDown() { - NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(false); - NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(false); - } - - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { - val localPath = System.getenv(TFGraphTestAllHelper.resourceFolderVar); - - // if this variable isn't set - we're using dl4j-tests-resources - if (localPath == null) { - File baseDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - } else { - File baseDir = new File(localPath); - return TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, baseDir); - } - } - - public TFGraphTestAllLibnd4j(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; - this.modelName = modelName; - this.localTestDir = localTestDir; - } - - @Test//(timeout = 25000L) - public void test() throws Exception { - if(TFGraphTestZooModels.isPPC()){ - /* - Ugly hack to temporarily disable tests on PPC only on CI - Issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7657 - These will be re-enabled for PPC once fixed - in the mean time, remaining tests will be used to detect and prevent regressions - */ - - log.warn("TEMPORARILY SKIPPING TEST ON PPC ARCHITECTURE DUE TO KNOWN JVM CRASH ISSUES - SEE https://github.com/deeplearning4j/deeplearning4j/issues/7657"); - OpValidationSuite.ignoreFailing(); - } - - Nd4j.create(1); - for(String s : TFGraphTestAllSameDiff.IGNORE_REGEXES){ - if(modelName.matches(s)){ - log.info("\n\tIGNORE MODEL ON REGEX: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); - } - } - - for(String s : SKIP_FOR_LIBND4J_EXEC){ - if(modelName.matches(s)){ - log.info("\n\tIGNORE MODEL ON REGEX - SKIP LIBND4J EXEC ONLY: {} - regex {}", modelName, s); - OpValidationSuite.ignoreFailing(); - } - } - - log.info("Starting test: {}", this.modelName); - Pair precisionOverride = TFGraphTestAllHelper.testPrecisionOverride(modelName); - Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst()); - Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond()); - - TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, - TFGraphTestAllHelper.LOADER, maxRE, minAbs, false); - //TFGraphTestAllHelper.checkIntermediate(inputs, modelName, EXECUTE_WITH); - } - -} diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 919bbe6fec91..7b70b7f8c02d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -150,7 +150,7 @@ public void tearDown() { @Parameterized.Parameters(name="{0}") public static Collection data() throws Exception { - Map m = TFGraphUtil.getTestCases(BASE_DIR, false); + Map m = TFGraphUtil.getTestCases(BASE_DIR, false, MODEL_FILENAME); List l = new ArrayList<>(m.keySet()); Collections.sort(l); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java index a6601418143c..1e814fd88a4d 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java @@ -83,7 +83,7 @@ public static Collection data() throws Exception { List out = new ArrayList<>(modelNames.length); for(int i=0; i inputs = TFGraphTestAllHelper.inputVars(modelName, MODEL_DIR, dir); - TFGraphTestAllHelper.checkIntermediate(inputs, modelName, MODEL_DIR, MODEL_FILENAME, executeWith, dir, printArraysDebugging); - } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java index d8550673e95b..b7ca3e3d3dfa 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestZooModels.java @@ -43,10 +43,7 @@ import java.io.IOException; import java.net.URL; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Map; +import java.util.*; @RunWith(Parameterized.class) @Slf4j @@ -103,9 +100,8 @@ public class TFGraphTestZooModels { //Note: Can't extend BaseNd4jTest here as we private static final String BASE_DIR = "tf_graphs/zoo_models"; private static final String MODEL_FILENAME = "tf_model.txt"; - private Map inputs; - private Map predictions; private String modelName; + private TestCase testCase; private File localTestDir; public static String getBaseModelDir(){ @@ -206,18 +202,26 @@ public static void beforeClass(){ Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); } - @Parameterized.Parameters(name="{2}") - public static Collection data() throws IOException { + @Parameterized.Parameters(name="{0}") + public static Collection data() throws Exception { + Map m = TFGraphUtil.getTestCases(BASE_DIR, false, MODEL_FILENAME); + classTestDir.create(); File baseDir = classTestDir.newFolder(); // new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); - List params = TFGraphTestAllHelper.fetchTestParams(BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, baseDir); - return params; + + List out = new ArrayList<>(); + List l = new ArrayList<>(m.keySet()); + Collections.sort(l); + for (String s : l) { + out.add(new Object[]{s, m.get(s), baseDir}); + } + + return out; } - public TFGraphTestZooModels(Map inputs, Map predictions, String modelName, File localTestDir) { - this.inputs = inputs; - this.predictions = predictions; + public TFGraphTestZooModels(String modelName, TestCase tc, File localTestDir) { this.modelName = modelName; + this.testCase = tc; this.localTestDir = localTestDir; } @@ -266,6 +270,10 @@ public void testOutputOnly() throws Exception { Double maxRE = 1e-3; Double minAbs = 1e-4; currentTestDir = testDir.newFolder(); + + Map inputs = TFGraphUtil.loadInputs(testCase); + Map predictions = TFGraphUtil.loadPredictions(testCase); + log.info("----- SameDiff Exec: {} -----", modelName); TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, TFGraphTestAllHelper.ExecuteWith.SAMEDIFF, LOADER, maxRE, minAbs, false); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java index 790118ae4966..2dbd34de9a15 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -40,28 +40,24 @@ public class TFGraphUtil { private TFGraphUtil() { } - public static TestCase getTestCase(String baseDir, String testName) throws Exception { + public static TestCase getTestCase(String baseDir, String testName, String modelFilename) throws Exception { String newBase = FilenameUtils.concat(baseDir, testName + "/"); - Map cases = getTestCases(newBase, true); + Map cases = getTestCases(newBase, true, modelFilename); Preconditions.checkState(cases.size() == 1, "Expected 1 test case, got %s", cases.size()); return cases.get(cases.keySet().iterator().next()); } - public static Map getTestCases(String baseDir, boolean singleTest) throws Exception { + public static Map getTestCases(String baseDir, boolean singleTest, String modelFilename) throws Exception { baseDir = baseDir.replaceAll("\\\\", "/"); if (!baseDir.endsWith("/")) baseDir += "/"; long start = System.currentTimeMillis(); -// String baseDir = "tf_graphs/examples/"; List l = ResourceUtils.listClassPathFiles(baseDir, true, false); long end = System.currentTimeMillis(); - Set listAsSet = new HashSet<>(l); - - - Set modelSet = new HashSet<>(); + Set set = new HashSet<>(l); Map map = new HashMap<>(); @@ -81,9 +77,8 @@ public static Map getTestCases(String baseDir, boolean singleT } else if (idx > 0) { name = sub.substring(0, idx); modelDir = baseDir + sub.substring(0, idx + 1); - String expModel = modelDir + TFGraphTestAllSameDiff.MODEL_FILENAME; -// while(!Resources.exists(expModel) && idx > 0){ - while (!listAsSet.contains(expModel) && idx > 0) { + String expModel = modelDir + modelFilename; + while (!set.contains(expModel) && idx > 0) { //Due to a mixing of directories and variable names - we //For example we might have "X/frozen_model.pb" //And then also "X/something/or/other.csv @@ -97,21 +92,15 @@ public static Map getTestCases(String baseDir, boolean singleT } sub = sub.substring(0, idx); - expModel = baseDir + sub + "/" + TFGraphTestAllSameDiff.MODEL_FILENAME; + expModel = baseDir + sub + "/" + modelFilename; modelDir = baseDir + sub + "/"; name = sub; } -// name = n; } -// if(modelDir == null) -// continue; if(badTest || modelDir == null) continue; - - modelSet.add(name); - TestCase tc = map.get(name); if (tc == null) { tc = new TestCase(name, null, null, null); @@ -156,7 +145,6 @@ public static Map getTestCases(String baseDir, boolean singleT } } } -// System.out.println(sub); } long end2 = System.currentTimeMillis(); From 796d78634fd98a062a6fa4f3c4da8c4dcba43dc8 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 12 May 2020 18:39:24 +1000 Subject: [PATCH 7/8] Clean up Signed-off-by: Alex Black --- .../test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java index 2dbd34de9a15..dc721572436a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphUtil.java @@ -53,16 +53,11 @@ public static Map getTestCases(String baseDir, boolean singleT if (!baseDir.endsWith("/")) baseDir += "/"; - long start = System.currentTimeMillis(); List l = ResourceUtils.listClassPathFiles(baseDir, true, false); - long end = System.currentTimeMillis(); Set set = new HashSet<>(l); - - Map map = new HashMap<>(); - long start2 = System.currentTimeMillis(); for (String s : l) { String sub = s.substring(baseDir.length()); @@ -146,10 +141,6 @@ public static Map getTestCases(String baseDir, boolean singleT } } } - long end2 = System.currentTimeMillis(); - - System.out.println("List duration: " + (end - start)); - System.out.println("Process duration: " + (end2 - start2)); return map; } From 7619ec50a2c08f77f54b043d2b10c0504882889e Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 15 May 2020 21:50:34 +1000 Subject: [PATCH 8/8] Update Signed-off-by: Alex Black --- .../java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java index bf6628ed2eb3..706f706d6231 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TF2ImportTestsSameDiff.java @@ -54,7 +54,7 @@ protected void starting(Description description){ private String baseDir; private static final TFGraphTestAllHelper.ExecuteWith EXECUTE_WITH = TFGraphTestAllHelper.ExecuteWith.SAMEDIFF; - public static final String[] BASE_DIRS = new String[]{"tf_graphs/examples2.1"}; //Add directories for any other TensorFlow versions here + public static final String[] BASE_DIRS = new String[]{"tf_graphs/examples2.2.0"}; //Add directories for any other TensorFlow versions here public static final String MODEL_FILENAME = "frozen_model.pb"; public static final String[] IGNORE_REGEXES = new String[]{