diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index afd1286692dc..1f7e938b6248 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -45,7 +45,7 @@ public int convertInt(Writable writable) { } public int convertInt(String o) { - return Integer.parseInt(o); + return (int) Double.parseDouble(o); } public double convertDouble(Writable writable) { diff --git a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 48f9474d5fdd..f76917063364 100644 --- a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -725,7 +725,6 @@ private static List createFieldVectors(BufferAllocator bufferAlloca case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break; default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i)); - } } @@ -795,6 +794,7 @@ public static void setValue(ColumnType columnType,FieldVector fieldVector,Object long timeSet = TypeConversion.getInstance().convertLong(value); setLongInTime(fieldVector, row, timeSet); break; + case Bytes: case NDArray: NDArrayWritable arr = (NDArrayWritable) value; VarBinaryVector nd4jArrayVector = (VarBinaryVector) fieldVector; @@ -802,8 +802,13 @@ public static void setValue(ColumnType columnType,FieldVector fieldVector,Object //for proper offsets ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get()); nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity()); + case Boolean: + BitVector bitVector = (BitVector) fieldVector; + if(value instanceof Boolean) + bitVector.set(row, (boolean) value ? 1 : 0); + else + bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0); break; - } }catch(Exception e) { log.warn("Unable to set value at row " + row); @@ -1216,6 +1221,7 @@ public static Writable fromEntry(int item,FieldVector from,ColumnType columnType case Time: //TODO: need to look at closer return new LongWritable(getLongFromFieldVector(item,from)); + case Bytes: case NDArray: VarBinaryVector valueVector = (VarBinaryVector) from; byte[] bytes = valueVector.get(item); diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index 183e43c189de..8181a4d63eb9 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -223,7 +223,6 @@ private PythonVariables getPyInputsFromWritables(List writables) { } else { ret.addInt(name, ((IntWritable) w).get()); } - break; case FLOAT: if (w instanceof DoubleWritable) { @@ -238,6 +237,9 @@ private PythonVariables getPyInputsFromWritables(List writables) { case NDARRAY: ret.addNDArray(name, ((NDArrayWritable) w).get()); break; + case BOOL: + ret.addBool(name, ((BooleanWritable) w).get()); + break; default: throw new RuntimeException("Unsupported input type:" + pyType); } @@ -270,6 +272,9 @@ private List getWritablesFromPyOutputs(PythonVariables pyOuts) { NumpyArray arr = pyOuts.getNDArrayValue(name); schemaBuilder.addColumnNDArray(name, arr.getShape()); break; + case BOOL: + schemaBuilder.addColumnBoolean(name); + break; default: throw new IllegalStateException("Unable to support type " + pyType.name()); } @@ -318,6 +323,9 @@ private List getWritablesFromPyOutputs(PythonVariables pyOuts) { throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); } break; + case BOOL: + out.add(new BooleanWritable(pyOuts.getBooleanValue(name))); + break; default: throw new IllegalStateException("Unable to support type " + pyType.name()); } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java index 18fdaf27a7e0..293260594229 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java @@ -126,6 +126,9 @@ public static PythonVariables schemaToPythonVariables(Schema schema) throws Exce case NDArray: pyVars.addNDArray(colName); break; + case Boolean: + pyVars.addBool(colName); + break; default: throw new Exception("Unsupported python input type: " + colType.toString()); }