Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public int convertInt(Writable writable) {
}

public int convertInt(String o) {
return Integer.parseInt(o);
return (int) Double.parseDouble(o);
}

public double convertDouble(Writable writable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,6 @@ private static List<FieldVector> createFieldVectors(BufferAllocator bufferAlloca
case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i));

}
}

Expand Down Expand Up @@ -795,15 +794,21 @@ public static void setValue(ColumnType columnType,FieldVector fieldVector,Object
long timeSet = TypeConversion.getInstance().convertLong(value);
setLongInTime(fieldVector, row, timeSet);
break;
case Bytes:
case NDArray:
NDArrayWritable arr = (NDArrayWritable) value;
VarBinaryVector nd4jArrayVector = (VarBinaryVector) fieldVector;
//slice the databuffer to use only the needed portion of the buffer
//for proper offsets
ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
case Boolean:
BitVector bitVector = (BitVector) fieldVector;
if(value instanceof Boolean)
bitVector.set(row, (boolean) value ? 1 : 0);
else
bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0);
break;

}
}catch(Exception e) {
log.warn("Unable to set value at row " + row);
Expand Down Expand Up @@ -1216,6 +1221,7 @@ public static Writable fromEntry(int item,FieldVector from,ColumnType columnType
case Time:
//TODO: need to look at closer
return new LongWritable(getLongFromFieldVector(item,from));
case Bytes:
case NDArray:
VarBinaryVector valueVector = (VarBinaryVector) from;
byte[] bytes = valueVector.get(item);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
} else {
ret.addInt(name, ((IntWritable) w).get());
}

break;
case FLOAT:
if (w instanceof DoubleWritable) {
Expand All @@ -238,6 +237,9 @@ private PythonVariables getPyInputsFromWritables(List<Writable> writables) {
case NDARRAY:
ret.addNDArray(name, ((NDArrayWritable) w).get());
break;
case BOOL:
ret.addBool(name, ((BooleanWritable) w).get());
break;
default:
throw new RuntimeException("Unsupported input type:" + pyType);
}
Expand Down Expand Up @@ -270,6 +272,9 @@ private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
NumpyArray arr = pyOuts.getNDArrayValue(name);
schemaBuilder.addColumnNDArray(name, arr.getShape());
break;
case BOOL:
schemaBuilder.addColumnBoolean(name);
break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
}
Expand Down Expand Up @@ -318,6 +323,9 @@ private List<Writable> getWritablesFromPyOutputs(PythonVariables pyOuts) {
throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!");
}
break;
case BOOL:
out.add(new BooleanWritable(pyOuts.getBooleanValue(name)));
break;
default:
throw new IllegalStateException("Unable to support type " + pyType.name());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ public static PythonVariables schemaToPythonVariables(Schema schema) throws Exce
case NDArray:
pyVars.addNDArray(colName);
break;
case Boolean:
pyVars.addBool(colName);
break;
default:
throw new Exception("Unsupported python input type: " + colType.toString());
}
Expand Down