From 918c6cb38d2685baa7eeb33502c4d96c73583bee Mon Sep 17 00:00:00 2001 From: shamsulazeem Date: Wed, 22 Jan 2020 21:57:32 +0500 Subject: [PATCH 1/2] Add support for boolean types in arrow records and ability to cast from float, double to int for TypeConversion --- .../transform/schema/conversion/TypeConversion.java | 2 +- .../main/java/org/datavec/arrow/ArrowConverter.java | 8 ++++++-- .../main/java/org/datavec/python/PythonTransform.java | 10 +++++++++- .../src/main/java/org/datavec/python/PythonUtils.java | 3 +++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index afd1286692dc..1f7e938b6248 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -45,7 +45,7 @@ public int convertInt(Writable writable) { } public int convertInt(String o) { - return Integer.parseInt(o); + return (int) Double.parseDouble(o); } public double convertDouble(Writable writable) { diff --git a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 48f9474d5fdd..f6647535762f 100644 --- a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -725,7 +725,6 @@ private static List createFieldVectors(BufferAllocator bufferAlloca case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break; default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i)); - } } @@ -802,8 +801,13 @@ public static void setValue(ColumnType columnType,FieldVector fieldVector,Object //for proper offsets ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get()); nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity()); + case Boolean: + BitVector bitVector = (BitVector) fieldVector; + if(value instanceof Boolean) + bitVector.set(row, (boolean) value ? 1 : 0); + else + bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0); break; - } }catch(Exception e) { log.warn("Unable to set value at row " + row); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index 183e43c189de..8181a4d63eb9 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -223,7 +223,6 @@ private PythonVariables getPyInputsFromWritables(List writables) { } else { ret.addInt(name, ((IntWritable) w).get()); } - break; case FLOAT: if (w instanceof DoubleWritable) { @@ -238,6 +237,9 @@ private PythonVariables getPyInputsFromWritables(List writables) { case NDARRAY: ret.addNDArray(name, ((NDArrayWritable) w).get()); break; + case BOOL: + ret.addBool(name, ((BooleanWritable) w).get()); + break; default: throw new RuntimeException("Unsupported input type:" + pyType); } @@ -270,6 +272,9 @@ private List getWritablesFromPyOutputs(PythonVariables pyOuts) { NumpyArray arr = pyOuts.getNDArrayValue(name); schemaBuilder.addColumnNDArray(name, arr.getShape()); break; + case BOOL: + schemaBuilder.addColumnBoolean(name); + break; default: throw new IllegalStateException("Unable to support type " + pyType.name()); } @@ -318,6 +323,9 @@ private List getWritablesFromPyOutputs(PythonVariables pyOuts) { throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); } break; + case BOOL: + out.add(new BooleanWritable(pyOuts.getBooleanValue(name))); + break; default: throw new IllegalStateException("Unable to support type " + pyType.name()); } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java index 18fdaf27a7e0..293260594229 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java @@ -126,6 +126,9 @@ public static PythonVariables schemaToPythonVariables(Schema schema) throws Exce case NDArray: pyVars.addNDArray(colName); break; + case Boolean: + pyVars.addBool(colName); + break; default: throw new Exception("Unsupported python input type: " + colType.toString()); } From 6f97e6a27649e15a587152d8a9abd12c7b48bd54 Mon Sep 17 00:00:00 2001 From: shamsulazeem Date: Fri, 24 Jan 2020 21:10:06 +0500 Subject: [PATCH 2/2] Reading NDArrays from Bytes for arrow records --- .../src/main/java/org/datavec/arrow/ArrowConverter.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index f6647535762f..f76917063364 100644 --- a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -794,6 +794,7 @@ public static void setValue(ColumnType columnType,FieldVector fieldVector,Object long timeSet = TypeConversion.getInstance().convertLong(value); setLongInTime(fieldVector, row, timeSet); break; + case Bytes: case NDArray: NDArrayWritable arr = (NDArrayWritable) value; VarBinaryVector nd4jArrayVector = (VarBinaryVector) fieldVector; @@ -1220,6 +1221,7 @@ public static Writable fromEntry(int item,FieldVector from,ColumnType columnType case Time: //TODO: need to look at closer return new LongWritable(getLongFromFieldVector(item,from)); + case Bytes: case NDArray: VarBinaryVector valueVector = (VarBinaryVector) from; byte[] bytes = valueVector.get(item);