-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkublas.cpp
More file actions
171 lines (100 loc) · 4.98 KB
/
kublas.cpp
File metadata and controls
171 lines (100 loc) · 4.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#include "kublas.h"
#include <jni.h>
#include <cublas_v2.h>
/**
 * Native backing for Kublas.create(): creates a cuBLAS handle via cublasCreate.
 *
 * @return on success, the cublasHandle_t reinterpreted as a jlong; on failure,
 *         the cublasStatus_t code widened to jlong.
 *         NOTE(review): handle values and status codes share this return
 *         channel — confirm the Java caller can tell them apart.
 */
JNIEXPORT jlong JNICALL Java_kuda_kublas_Kublas_create(JNIEnv* env, jobject obj) {
    cublasHandle_t created = nullptr;
    const cublasStatus_t status = cublasCreate(&created);
    return (status == CUBLAS_STATUS_SUCCESS)
        ? reinterpret_cast<jlong>(created)
        : static_cast<jlong>(status);
}
/**
 * Native backing for Kublas.destroy(): releases the cuBLAS handle that was
 * previously returned as a jlong.
 *
 * @param handle cublasHandle_t encoded as a jlong.
 * @return the cublasStatus_t returned by cublasDestroy, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_destroy(JNIEnv* env, jobject obj, jlong handle) {
    return static_cast<jint>(cublasDestroy(reinterpret_cast<cublasHandle_t>(handle)));
}
/**
 * Native backing for Kublas.getVersion(): queries cublasGetVersion for a handle.
 *
 * @param handle cublasHandle_t encoded as a jlong.
 * @return the version reported by cuBLAS on success; otherwise the
 *         cublasStatus_t code.
 *         NOTE(review): version numbers and status codes share this return
 *         value — confirm the Java caller disambiguates them.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_getVersion(JNIEnv* env, jobject obj, jlong handle) {
    int libraryVersion = 0;
    const cublasStatus_t status =
        cublasGetVersion(reinterpret_cast<cublasHandle_t>(handle), &libraryVersion);
    if (status == CUBLAS_STATUS_SUCCESS) {
        return libraryVersion;
    }
    return static_cast<jint>(status);
}
/**
 * Native backing for Kublas.getProperty(): queries cublasGetProperty.
 *
 * @param type numeric value of a libraryPropertyType_t enumerator.
 * @return the property value on success; otherwise the cublasStatus_t code.
 *         NOTE(review): property values and status codes share this return
 *         value — confirm the Java caller disambiguates them.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_getProperty(JNIEnv* env, jobject obj, jint type) {
    int propertyValue = 0;
    const cublasStatus_t status =
        cublasGetProperty(static_cast<libraryPropertyType_t>(type), &propertyValue);
    if (status == CUBLAS_STATUS_SUCCESS) {
        return propertyValue;
    }
    return static_cast<jint>(status);
}
/**
 * Native backing for Kublas.getStatusName(): converts a cublasStatus_t code
 * into the name string from cublasGetStatusName.
 *
 * @param status numeric cublasStatus_t value.
 * @return a new Java string built from that name via NewStringUTF.
 */
JNIEXPORT jstring JNICALL Java_kuda_kublas_Kublas_getStatusName(JNIEnv* env, jobject obj, jint status) {
    const cublasStatus_t code = static_cast<cublasStatus_t>(status);
    const char* statusName = cublasGetStatusName(code);
    return env->NewStringUTF(statusName);
}
/**
 * Native backing for Kublas.getStatusString(): converts a cublasStatus_t code
 * into the description string from cublasGetStatusString.
 *
 * @param status numeric cublasStatus_t value.
 * @return a new Java string built from that description via NewStringUTF.
 */
JNIEXPORT jstring JNICALL Java_kuda_kublas_Kublas_getStatusString(JNIEnv* env, jobject obj, jint status) {
    const cublasStatus_t code = static_cast<cublasStatus_t>(status);
    const char* description = cublasGetStatusString(code);
    return env->NewStringUTF(description);
}
/**
 * Native backing for Kublas.setStream(): binds a CUDA stream to a cuBLAS handle.
 *
 * @param handle   cublasHandle_t encoded as a jlong.
 * @param streamId cudaStream_t encoded as a jlong.
 * @return the cublasStatus_t returned by cublasSetStream, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_setStream(JNIEnv* env, jobject obj, jlong handle, jlong streamId) {
    cublasHandle_t target = reinterpret_cast<cublasHandle_t>(handle);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(streamId);
    return static_cast<jint>(cublasSetStream(target, stream));
}
/**
 * Native backing for Kublas.setWorkspace(): calls cublasSetWorkspace on a handle.
 *
 * @param handle               cublasHandle_t encoded as a jlong.
 * @param workspaceSizeInBytes requested workspace size; must be non-negative.
 * @return CUBLAS_STATUS_INVALID_VALUE if the size is negative, otherwise the
 *         cublasStatus_t returned by cublasSetWorkspace, as a jint.
 *
 * NOTE(review): the workspace pointer handed to cuBLAS is always nullptr.
 * This binding never supplies a real device buffer, so a non-zero size
 * describes memory that does not exist. The Java signature has no pointer
 * argument; confirm the intended semantics before relying on this call.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_setWorkspace(JNIEnv* env, jobject obj, jlong handle, jsize workspaceSizeInBytes) {
    // jsize is signed: a negative value would wrap to a huge size_t, so reject it here.
    if (workspaceSizeInBytes < 0) {
        return static_cast<jint>(CUBLAS_STATUS_INVALID_VALUE);
    }
    void* workspace = nullptr;
    cublasHandle_t cublasHandle = reinterpret_cast<cublasHandle_t>(handle);
    cublasStatus_t cublasStatus =
        cublasSetWorkspace(cublasHandle, workspace, static_cast<size_t>(workspaceSizeInBytes));
    return static_cast<jint>(cublasStatus);
}
/**
 * Native backing for Kublas.getStream(): reads the CUDA stream bound to a
 * cuBLAS handle.
 *
 * @param handle cublasHandle_t encoded as a jlong.
 * @return on success, the cudaStream_t reinterpreted as a jlong; on failure,
 *         the cublasStatus_t code widened to jlong.
 *         NOTE(review): stream values and status codes share this return
 *         channel — confirm the Java caller can tell them apart.
 */
JNIEXPORT jlong JNICALL Java_kuda_kublas_Kublas_getStream(JNIEnv* env, jobject obj, jlong handle) {
    cudaStream_t current = nullptr;
    const cublasStatus_t status =
        cublasGetStream(reinterpret_cast<cublasHandle_t>(handle), &current);
    if (status != CUBLAS_STATUS_SUCCESS) {
        return static_cast<jlong>(status);
    }
    return reinterpret_cast<jlong>(current);
}
/**
 * Native backing for Kublas.getPointerMode(): reads the pointer mode of a
 * cuBLAS handle.
 *
 * @param handle cublasHandle_t encoded as a jlong.
 * @return the cublasPointerMode_t value as an int on success; otherwise the
 *         cublasStatus_t code.
 *         NOTE(review): pointer-mode values and status codes are both small
 *         integers and may overlap — confirm the Java caller disambiguates.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_getPointerMode(JNIEnv* env, jobject obj, jlong handle) {
    // Start from HOST so the variable holds a defined value before the query.
    cublasPointerMode_t mode = CUBLAS_POINTER_MODE_HOST;
    const cublasStatus_t status =
        cublasGetPointerMode(reinterpret_cast<cublasHandle_t>(handle), &mode);
    return (status == CUBLAS_STATUS_SUCCESS)
        ? static_cast<int>(mode)
        : static_cast<jint>(status);
}
/**
 * Native backing for Kublas.setPointerMode(): sets the pointer mode of a
 * cuBLAS handle.
 *
 * @param handle cublasHandle_t encoded as a jlong.
 * @param mode   numeric cublasPointerMode_t value.
 * @return the cublasStatus_t returned by cublasSetPointerMode, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_setPointerMode(JNIEnv* env, jobject obj, jlong handle, jint mode) {
    cublasHandle_t target = reinterpret_cast<cublasHandle_t>(handle);
    const cublasPointerMode_t requested = static_cast<cublasPointerMode_t>(mode);
    return static_cast<jint>(cublasSetPointerMode(target, requested));
}
/**
 * Native backing for Kublas.setVector(): forwards its arguments to
 * cublasSetVector, reinterpreting the jlong addresses as raw pointers.
 *
 * Presumably x is a host address and y a device address (cuBLAS set
 * direction) — confirm against the Java caller.
 *
 * @return the cublasStatus_t returned by cublasSetVector, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_setVector(JNIEnv* env, jobject obj, jint n, jint elemSize, jlong x, jint incx, jlong y, jint incy) {
    const void* source = reinterpret_cast<const void*>(x);
    void* destination = reinterpret_cast<void*>(y);
    return static_cast<jint>(cublasSetVector(n, elemSize, source, incx, destination, incy));
}
/**
 * Native backing for Kublas.getVector(): forwards its arguments to
 * cublasGetVector, reinterpreting the jlong addresses as raw pointers.
 *
 * Presumably x is a device address and y a host address (cuBLAS get
 * direction) — confirm against the Java caller.
 *
 * @return the cublasStatus_t returned by cublasGetVector, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_getVector(JNIEnv* env, jobject obj, jint n, jint elemSize, jlong x, jint incx, jlong y, jint incy) {
    const void* source = reinterpret_cast<const void*>(x);
    void* destination = reinterpret_cast<void*>(y);
    return static_cast<jint>(cublasGetVector(n, elemSize, source, incx, destination, incy));
}
/**
 * Native backing for Kublas.setMatrix(): forwards its arguments to
 * cublasSetMatrix, reinterpreting the jlong addresses as raw pointers.
 *
 * Presumably A is a host address and B a device address — confirm against
 * the Java caller.
 *
 * @return the cublasStatus_t returned by cublasSetMatrix, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_setMatrix(JNIEnv* env, jobject obj, jint rows, jint cols, jint elemSize, jlong A, jint lda, jlong B, jint ldb) {
    const void* source = reinterpret_cast<const void*>(A);
    void* destination = reinterpret_cast<void*>(B);
    return static_cast<jint>(cublasSetMatrix(rows, cols, elemSize, source, lda, destination, ldb));
}
/**
 * Native backing for Kublas.getMatrix(): forwards its arguments to
 * cublasGetMatrix, reinterpreting the jlong addresses as raw pointers.
 *
 * Presumably A is a device address and B a host address — confirm against
 * the Java caller.
 *
 * @return the cublasStatus_t returned by cublasGetMatrix, as a jint.
 */
JNIEXPORT jint JNICALL Java_kuda_kublas_Kublas_getMatrix(JNIEnv* env, jobject obj, jint rows, jint cols, jint elemSize, jlong A, jint lda, jlong B, jint ldb) {
    const void* source = reinterpret_cast<const void*>(A);
    void* destination = reinterpret_cast<void*>(B);
    return static_cast<jint>(cublasGetMatrix(rows, cols, elemSize, source, lda, destination, ldb));
}