-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathRNTensorflowInference.java
More file actions
143 lines (121 loc) · 4.95 KB
/
RNTensorflowInference.java
File metadata and controls
143 lines (121 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package com.rntensorflow;
import com.facebook.react.bridge.*;

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.contrib.android.RunStats;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.rntensorflow.converter.ArrayConverter.*;
import static com.rntensorflow.converter.ArrayConverter.byteArrayToBoolReadableArray;
import static com.rntensorflow.converter.ArrayConverter.intArrayToReadableArray;
/**
 * Runs inference on a TensorFlow graph loaded from a bundled model resource and
 * exposes the results as React Native {@link ReadableArray}s.
 *
 * <p>Typical lifecycle: {@link #init}, then repeated {@link #feed} / {@link #run} /
 * {@link #fetch} cycles, then {@link #close} to release native memory.
 */
public class RNTensorflowInference {

    private final ReactContext reactContext;
    private final TfContext tfContext;

    public RNTensorflowInference(ReactContext reactContext, TfContext tfContext) {
        this.reactContext = reactContext;
        this.tfContext = tfContext;
    }

    /**
     * Loads the native TensorFlow library (if not already loaded) and builds an
     * inference context for the given model resource.
     *
     * @param reactContext React Native context used to resolve the model resource
     * @param model        resource identifier passed to {@code ResourceManager}
     * @throws IOException if the model resource cannot be read
     */
    public static RNTensorflowInference init(ReactContext reactContext, String model) throws IOException {
        loadNativeTf();
        TfContext context = createContext(reactContext, model);
        return new RNTensorflowInference(reactContext, context);
    }

    /** Ensures the native "tensorflow_inference" library is loaded. */
    private static void loadNativeTf() {
        try {
            // RunStats is backed by the native library; constructing it probes
            // whether the library is already linked into the process.
            new RunStats();
        } catch (UnsatisfiedLinkError ignored) {
            // Not loaded yet — load it explicitly.
            System.loadLibrary("tensorflow_inference");
        }
    }

    /** Reads the serialized GraphDef from resources and opens a session over it. */
    private static TfContext createContext(ReactContext reactContext, String model) throws IOException {
        byte[] graphDef = new ResourceManager(reactContext).loadResource(model);
        Graph graph = new Graph();
        graph.importGraphDef(graphDef);
        Session session = new Session(graph);
        Session.Runner runner = session.runner();
        return new TfContext(session, runner, graph);
    }

    /** Binds {@code tensor} to the graph input named {@code inputName} for the next run. */
    public void feed(String inputName, Tensor tensor) {
        tfContext.runner.feed(inputName, tensor);
    }

    /**
     * Executes the graph and stores the tensors for {@code outputNames} so they can
     * be retrieved with {@link #fetch}.
     *
     * @param outputNames graph output node names to fetch
     * @param enableStats NOTE(review): currently unused — confirm whether RunStats
     *                    instrumentation was intended here
     * @throws IllegalStateException if this instance has no inference context
     */
    public void run(String[] outputNames, boolean enableStats) {
        if (tfContext == null) {
            throw new IllegalStateException("Could not find inference for id");
        }
        for (String outputName : outputNames) {
            tfContext.runner.fetch(outputName);
        }
        List<Tensor<?>> tensors = tfContext.runner.run();
        // Close the previous run's tensors before replacing them; Tensor handles
        // hold native memory that clear() alone would leak.
        closeOutputTensors();
        for (int i = 0; i < outputNames.length; i++) {
            tfContext.outputTensors.put(outputNames[i], tensors.get(i));
        }
    }

    /**
     * Converts the output tensor named {@code outputName} (from the last {@link #run})
     * into a {@link ReadableArray}.
     *
     * @throws IllegalArgumentException if the name was never fetched or the tensor's
     *                                  data type is unsupported
     */
    public ReadableArray fetch(String outputName) {
        Tensor tensor = tfContext.outputTensors.get(outputName);
        if (tensor == null) {
            throw new IllegalArgumentException("No output tensor available for name: " + outputName);
        }
        int numElements = tensor.numElements();
        switch (tensor.dataType()) {
            case DOUBLE: {
                DoubleBuffer dst = DoubleBuffer.allocate(numElements);
                tensor.writeTo(dst);
                return doubleArrayToReadableArray(dst.array());
            }
            case FLOAT: {
                FloatBuffer dst = FloatBuffer.allocate(numElements);
                tensor.writeTo(dst);
                return floatArrayToReadableArray(dst.array());
            }
            case INT32: {
                IntBuffer dst = IntBuffer.allocate(numElements);
                tensor.writeTo(dst);
                return intArrayToReadableArray(dst.array());
            }
            case INT64: {
                // Bug fix: an INT64 tensor cannot be written into a DoubleBuffer —
                // Tensor.writeTo(DoubleBuffer) requires dtype DOUBLE and throws.
                // Read as longs, then widen to double for the JS bridge (which has
                // no 64-bit integer array type).
                LongBuffer dst = LongBuffer.allocate(numElements);
                tensor.writeTo(dst);
                long[] longs = dst.array();
                double[] doubles = new double[longs.length];
                for (int i = 0; i < longs.length; i++) {
                    doubles[i] = longs[i];
                }
                return doubleArrayToReadableArray(doubles);
            }
            case UINT8: {
                // Bug fix: a UINT8 tensor cannot be written into an IntBuffer —
                // Tensor.writeTo(IntBuffer) requires dtype INT32. Read the raw
                // bytes and widen each unsigned byte to int.
                ByteBuffer dst = ByteBuffer.allocate(numElements);
                tensor.writeTo(dst);
                byte[] bytes = dst.array();
                int[] ints = new int[bytes.length];
                for (int i = 0; i < bytes.length; i++) {
                    ints[i] = bytes[i] & 0xFF;
                }
                return intArrayToReadableArray(ints);
            }
            case BOOL: {
                ByteBuffer dst = ByteBuffer.allocate(numElements);
                tensor.writeTo(dst);
                return byteArrayToBoolReadableArray(dst.array());
            }
            default:
                throw new IllegalArgumentException("Data type is not supported");
        }
    }

    /**
     * Releases all native resources: cached output tensors, the session, and the graph.
     *
     * @throws IllegalStateException if this instance has no inference context
     */
    public void close() {
        if (tfContext == null) {
            throw new IllegalStateException("Could not find inference for id");
        }
        closeOutputTensors();
        tfContext.session.close();
        // Bug fix: the Graph holds native memory and was never closed before.
        tfContext.graph.close();
    }

    /** Closes and discards all cached output tensors (they hold native memory). */
    private void closeOutputTensors() {
        for (Tensor tensor : tfContext.outputTensors.values()) {
            tensor.close();
        }
        tfContext.outputTensors.clear();
    }

    public TfContext getTfContext() {
        return tfContext;
    }

    /** Bundles the session, its current runner, the graph, and cached run outputs. */
    public static class TfContext {
        final Session session;
        Session.Runner runner;
        final Graph graph;
        // Output name -> native tensor from the most recent run; entries must be
        // closed before being discarded.
        private final Map<String, Tensor> outputTensors;

        TfContext(Session session, Session.Runner runner, Graph graph) {
            this.session = session;
            this.runner = runner;
            this.graph = graph;
            outputTensors = new HashMap<>();
        }

        /** Starts a fresh runner and releases the previous run's output tensors. */
        public void reset() {
            runner = session.runner();
            // Bug fix: close the native tensors instead of only clearing the map.
            for (Tensor tensor : outputTensors.values()) {
                tensor.close();
            }
            outputTensors.clear();
        }
    }
}