-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathTsinghuaMixQPlugin.cpp
More file actions
executable file
·963 lines (729 loc) · 28.4 KB
/
TsinghuaMixQPlugin.cpp
File metadata and controls
executable file
·963 lines (729 loc) · 28.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "TsinghuaMixQPlugin.h"

#include <stdio.h>

#include <numeric>
#include <stdexcept>

#include <cublas_v2.h>
#include <cublasLt.h>
#include <cuda_runtime_api.h>
// Fallback workspace size in bytes (32 MiB). getWorkspaceSize() returns it when no
// positive size has been configured.
#define CUBLAS_WORKSPACE_SIZE 33554432
// Evaluates a cuBLAS call exactly once. If it does not return CUBLAS_STATUS_SUCCESS,
// prints file/line and the numeric status to stderr and terminates the process.
// The local uses a distinctive name so it cannot shadow or capture a caller variable
// (e.g. one named `status`) referenced inside `call`.
#define CUBLAS_CHECK(call)                                                                  \
    do                                                                                      \
    {                                                                                       \
        cublasStatus_t const cublas_check_status_ = (call);                                 \
        if (cublas_check_status_ != CUBLAS_STATUS_SUCCESS)                                  \
        {                                                                                   \
            fprintf(stderr, "cuBLAS error at %s:%d : %d\n", __FILE__, __LINE__,             \
                static_cast<int>(cublas_check_status_));                                    \
            exit(EXIT_FAILURE);                                                             \
        }                                                                                   \
    } while (0)
void gemm(
const int8_t * mat1,
const int8_t * mat2, int *mat3, int m, int n, int k,cublasHandle_t handle, cudaStream_t stream) {
static int64_t _beta = 0;
static int64_t _alpha = 1;
auto beta_ptr = (void*)&_beta;
auto alpha_ptr = (void*)&_alpha;
auto input_ptr = (void*)mat3;
auto mat1_ptr = (void*)mat1;
auto mat2_ptr = (void*)mat2;
//cublasHandle_t handle;
(cublasGemmEx(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n,
m,
k,
alpha_ptr,
mat2_ptr,
CUDA_R_8I,
k,
mat1_ptr,
CUDA_R_8I,
k,
beta_ptr,
input_ptr,
CUDA_R_32I,
n,
CUBLAS_COMPUTE_32I,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void gemmfp16add(
const half * mat1,
const half * mat2, half *mat3, int m, int n, int k,cublasHandle_t handle, cudaStream_t stream) {
static float _beta = 1.0;
static float _alpha = 1.0;
auto beta_ptr = (void*)&_beta;
auto alpha_ptr = (void*)&_alpha;
auto input_ptr = (void*)mat3;
auto mat1_ptr = (void*)mat1;
auto mat2_ptr = (void*)mat2;
(cublasGemmEx(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n,
m,
k,
alpha_ptr,
mat2_ptr,
CUDA_R_16F,
k,
mat1_ptr,
CUDA_R_16F,
k,
beta_ptr,
input_ptr,
CUDA_R_16F,
n,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void gemmfp16(
const half * mat1,
const half * mat2, half *mat3, int m, int n, int k, cublasHandle_t handle, cudaStream_t stream) {
static float _beta = 0.0;
static float _alpha = 1.0;
auto beta_ptr = (void*)&_beta;
auto alpha_ptr = (void*)&_alpha;
auto input_ptr = (void*)mat3;
auto mat1_ptr = (void*)mat1;
auto mat2_ptr = (void*)mat2;
(cublasGemmEx(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
n,
m,
k,
alpha_ptr,
mat2_ptr,
CUDA_R_16F,
k,
mat1_ptr,
CUDA_R_16F,
k,
beta_ptr,
input_ptr,
CUDA_R_16F,
n,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
// Import a generated header to use generated triton kernels.
extern "C"
{
// #include "aot/fmha_kernel_fp16.h"
// #include "aot/fmha_kernel_fp32.h"
}
#include "kernel/int8FusedDequantizeCUDA.h"
#include "weightonlykernel/fpA_intB_gemm_wrapper.h"
#include <cstring>
#include <cuda_fp16.h>
#include <iostream>
#include <string>
using namespace nvinfer1;
using openai_triton::plugin::MixQPluginCreator;
using openai_triton::plugin::MixQPlugin;
// Identity strings reported by both MixQPlugin and MixQPluginCreator. The identifiers
// appear to be carried over from a Triton flash-attention plugin template; the name
// actually registered is "MixQ".
static char const* const TRITON_FLASH_ATTENTION_PLUGIN_VERSION = "1";
static char const* const TRITON_FLASH_ATTENTION_PLUGIN_NAME = "MixQ";

// Static storage backing MixQPluginCreator::getFieldNames().
PluginFieldCollection MixQPluginCreator::mFC{};
std::vector<PluginField> MixQPluginCreator::mPluginAttributes;
namespace openai_triton::plugin
{
// Copies the raw bytes of `val` to `buffer`, then moves `buffer` just past them so
// consecutive calls lay fields out back to back.
template <typename T>
void writeArg(char*& buffer, T const& val)
{
    // memcpy returns its destination, so the cursor can be advanced in one step.
    buffer = static_cast<char*>(std::memcpy(buffer, &val, sizeof(T))) + sizeof(T);
}
// Fills `val` from the next sizeof(T) bytes at `buffer` and moves `buffer` past them;
// the counterpart of writeArg().
template <typename T>
void readArg(char const*& buffer, T& val)
{
    char const* const next = buffer + sizeof(T);
    std::memcpy(&val, buffer, sizeof val);
    buffer = next;
}
// Byte alignment applied to every sub-buffer carved out of the plugin workspace.
std::uintptr_t constexpr kCudaMemAlign = 128;

// Returns the first kCudaMemAlign-aligned address at or after
// `ptr + previousWorkspaceSize`, i.e. where the next sub-buffer may start.
int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
    uintptr_t const end = reinterpret_cast<uintptr_t>(ptr) + previousWorkspaceSize;
    uintptr_t const misalignment = end % kCudaMemAlign;
    uintptr_t const aligned = (misalignment == 0) ? end : end + (kCudaMemAlign - misalignment);
    return reinterpret_cast<int8_t*>(aligned);
}
// Builds the plugin from explicit (m, n, k) attributes. They are stored and serialized;
// enqueueImpl() derives the sizes it actually uses from the runtime tensor shapes.
MixQPlugin::MixQPlugin(
    int m, int n, int k)
    : mm(m)
    , mn(n)
    , mk(k)
{
}

// Deserializing constructor: reads (mm, mn, mk) back in the order serialize() writes them.
// Throws std::runtime_error if the buffer is null or shorter than those fields, instead
// of reading past its end (MixQPluginCreator::deserializePlugin catches std::exception).
MixQPlugin::MixQPlugin(void const* data, size_t length)
{
    size_t const required = sizeof(mm) + sizeof(mn) + sizeof(mk);
    if (data == nullptr || length < required)
    {
        throw std::runtime_error("MixQPlugin: truncated serialization buffer");
    }
    char const* d = static_cast<char const*>(data);
    readArg(d, mm);
    readArg(d, mn);
    readArg(d, mk);
}
// IPluginV2DynamicExt Methods

// Returns a heap-allocated copy of this plugin with the same namespace. Returns nullptr
// if allocation or the namespace copy throws, since clone() is noexcept and any escaping
// exception would call std::terminate.
nvinfer1::IPluginV2DynamicExt* MixQPlugin::clone() const noexcept
{
    MixQPlugin* plugin = nullptr;
    try
    {
        plugin = new MixQPlugin(*this);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
        delete plugin; // nullptr if `new` itself threw
    }
    return nullptr;
}
// Output shape: the activation's shape with its last dim replaced by N.
//   inputs[0] = activation [..., K], inputs[1] = weight [N, ...]  ->  output [..., N]
nvinfer1::DimsExprs MixQPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    assert(outputIndex == 0);
    int const nbDimsA = inputs[0].nbDims;
    DimsExprs ret;
    ret.nbDims = nbDimsA;
    for (int ii = 0; ii < nbDimsA - 1; ++ii)
    {
        ret.d[ii] = inputs[0].d[ii];
    }
    // Forward N as an expression. When the weight dim is a build-time constant this is
    // still a constant expression. Unlike exprBuilder.constant(getConstantValue()), it
    // does not turn into an invalid constant (-1) when N is not known at build time.
    ret.d[nbDimsA - 1] = inputs[1].d[0];
    return ret;
}
bool MixQPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
//printf(" pos =%d --------------------",pos);
switch (pos)
{
case 0:
// activation
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 1:
// weights
// Weights stored in checkpoint must have int8 type
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 2:
// scales channels
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 3:
// fp weight
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 4:
// ind
// std:: cout << inOut[pos].type << std::endl;
// std:: cout << nvinfer1::DataType::kINT32 << std::endl;
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
// 增加3个
case 5:
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 6:
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
// case 7:
// return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
case 7:
// out
return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
default:
// Never should be here
assert(false);
return false;
}
}
// Records the workspace size needed for the largest configured input shapes.
//   in[0] = activation [..., K] (M = product of the leading dims), in[1] = weight [N, ...]
// K and N must be static (min == max).
// All sizes are computed in size_t. The previous int arithmetic could overflow for
// large M*N (e.g. maxM * maxN * 16 bytes).
void MixQPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
    auto const& actMax = in[0].max;
    size_t const maxM
        = std::accumulate(actMax.d, actMax.d + actMax.nbDims - 1, size_t{1}, std::multiplies<size_t>());
    int const maxK = in[0].max.d[in[0].max.nbDims - 1];
    int const maxN = in[1].max.d[0];
    int const minK = in[0].min.d[in[0].min.nbDims - 1];
    int const minN = in[1].min.d[0];
    assert(minN == maxN);
    assert(minK == maxK);

    size_t const K = static_cast<size_t>(maxK);
    size_t const N = static_cast<size_t>(maxN);

    // Prefill: int8 quantized activation + per-row scales + fp16 buffer. Each sub-buffer is
    // placed with nextWorkspacePtr(), which may insert up to kCudaMemAlign bytes of padding.
    size_t const prefillBytes
        = maxM * K * sizeof(int8_t) + maxM * sizeof(half) + K * N * sizeof(half) + 3 * kCudaMemAlign;
    // Decode budget (originally labelled "for awq").
    size_t const decodeBytes = maxM * N * sizeof(half) * 8;

    // NOTE(review): if m_workspaceMaxSize is declared as int in the header, sizes above
    // INT_MAX are still truncated here; widening it to size_t needs a header change.
    m_workspaceMaxSize = prefillBytes > decodeBytes ? prefillBytes : decodeBytes;
}
// Returns the workspace size recorded by configurePlugin(). Falls back to
// CUBLAS_WORKSPACE_SIZE when that value is zero or negative. The value is returned
// without narrowing through int, so sizes above INT_MAX are not truncated.
size_t MixQPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    if (m_workspaceMaxSize <= 0)
    {
        return CUBLAS_WORKSPACE_SIZE;
    }
    return static_cast<size_t>(m_workspaceMaxSize);
}
// Runs the mixed-precision linear layer on `stream`.
//   Prefill (M > kMaxDecodeRows):
//     1. Extract kNumOutliers outlier activation columns (selected by `ind`) into an fp16 buffer.
//     2. Out = outliers[M x kNumOutliers] * fp_weight[N x kNumOutliers]^T via cuBLAS.
//     3. Quantize A per row to int8.
//     4. Run the fused int8 GEMM + dequantization kernel on Out.
//   Decode (M <= kMaxDecodeRows): w8_a16_gemm_forward_cuda on the fp16 activation.
// Returns 0 on success, or 1 if binding the cuBLAS stream fails or a CUDA launch
// error is pending afterwards.
int MixQPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs,
    void* workspace,
    cudaStream_t stream)
{
    constexpr int kMaxDecodeRows = 4; // M at or below this takes the decode path
    constexpr int kNumOutliers = 128; // outlier activation columns handled in fp16

    // Problem size: the activation [..., K] is flattened to M rows; weight is [N, ...].
    int M = 1;
    for (int ii = 0; ii < inputDesc[0].dims.nbDims - 1; ++ii)
    {
        M *= inputDesc[0].dims.d[ii];
    }
    int K = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
    int N = inputDesc[1].dims.d[0];

    half* Out = reinterpret_cast<half*>(outputs[0]);

    // All inputs are declared kHALF in supportsFormatCombination(). The integer-valued
    // ones are reinterpreted here, so their bytes are assumed to have been packed
    // accordingly by the exporter (TODO confirm).
    half const* A = reinterpret_cast<half const*>(inputs[0]);                // activation
    int8_t const* W = reinterpret_cast<int8_t const*>(inputs[1]);            // int8 weight (prefill)
    half const* scale_b = reinterpret_cast<half const*>(inputs[2]);          // weight scales (prefill)
    half const* fp_weight = reinterpret_cast<half const*>(inputs[3]);        // fp16 outlier weight
    int const* ind = reinterpret_cast<int const*>(inputs[4]);                // outlier column indices
    int8_t const* q_weight = reinterpret_cast<int8_t const*>(inputs[5]);     // int8 weight (decode)
    half const* scaling_factors = reinterpret_cast<half const*>(inputs[6]);  // weight scales (decode)

    if (M > kMaxDecodeRows)
    {
        // Workspace layout, each sub-buffer aligned by nextWorkspacePtr():
        //   int8_out       M * K                int8   quantized activation
        //   scale_a        M                    half   per-row activation scales
        //   fp_activation  M * kNumOutliers     half   extracted outlier activations
        int8_t* int8_out = nextWorkspacePtr(static_cast<int8_t*>(workspace), 0);
        half* scale_a = reinterpret_cast<half*>(nextWorkspacePtr(int8_out, sizeof(int8_t) * M * K));
        half* fp_activation
            = reinterpret_cast<half*>(nextWorkspacePtr(reinterpret_cast<int8_t*>(scale_a), sizeof(half) * M));

        ExtractOutliersAndSetToZeros(M, K, A, fp_activation, ind, kNumOutliers, stream);
        if (cublasSetStream(handle, stream) != CUBLAS_STATUS_SUCCESS)
        {
            fprintf(stderr, "MixQPlugin::enqueue: cublasSetStream failed\n");
            return 1;
        }
        gemmfp16(fp_activation, fp_weight, Out, M, N, kNumOutliers, handle, stream);
        int8quant(M, K, A, int8_out, scale_a, stream);
        // Out is passed twice, presumably as the addend and the destination, so the
        // int8 result is added to the fp16 outlier result (TODO confirm).
        // NOTE(review): the start of `workspace`, which also holds int8_out/scale_a, is
        // passed as the kernel's scratch buffer; confirm it never writes over them.
        int8FusedDequantizeCUDA(int8_out, W, scale_a, scale_b, Out, Out, M, N, K,
            reinterpret_cast<char*>(workspace), stream);
    }
    else
    {
        // Decode: "w8_a16" presumably means 8-bit weights with fp16 activations.
        w8_a16_gemm_forward_cuda(A, q_weight, scaling_factors, Out, M, N, K, stream);
    }

    // Surface kernel-launch failures to TensorRT instead of always reporting success.
    cudaError_t const launchStatus = cudaGetLastError();
    if (launchStatus != cudaSuccess)
    {
        fprintf(stderr, "MixQPlugin::enqueue failed: %s\n", cudaGetErrorString(launchStatus));
        return 1;
    }
    return 0;
}
// TensorRT entry point. enqueue() is noexcept, so any C++ exception thrown while
// enqueueing work is turned into a non-zero status instead of calling std::terminate.
int MixQPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    try
    {
        return enqueueImpl(inputDesc, outputDesc, inputs, outputs, workspace, stream);
    }
    catch (std::exception const& e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
    }
    catch (...)
    {
        std::cerr << "Caught unknown exception in MixQPlugin::enqueue" << std::endl;
    }
    return 1;
}
// IPluginV2Ext Methods
nvinfer1::DataType MixQPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    // The single output is always fp16, whatever the input types are.
    assert(index == 0);
    nvinfer1::DataType const outputType = nvinfer1::DataType::kHALF;
    return outputType;
}
// IPluginV2 Methods
// Type and version are the same strings MixQPluginCreator reports.
char const* MixQPlugin::getPluginType() const noexcept
{
    char const* const pluginType = TRITON_FLASH_ATTENTION_PLUGIN_NAME;
    return pluginType;
}
char const* MixQPlugin::getPluginVersion() const noexcept
{
    char const* const pluginVersion = TRITON_FLASH_ATTENTION_PLUGIN_VERSION;
    return pluginVersion;
}
int MixQPlugin::getNbOutputs() const noexcept
{
    constexpr int kNumOutputs = 1;
    return kNumOutputs;
}
int MixQPlugin::initialize() noexcept
{
// Load kernels generated by Triton AoT.
// load_fmha_d64_fp32();
// load_fmha_d64_fp16();
cublasCreate(&handle);
return 0;
}
void MixQPlugin::terminate() noexcept
{
// Unload kernels generated by Triton AoT.
// unload_fmha_d64_fp32();
// unload_fmha_d64_fp16();
}
// Serialized layout: mm, mn, mk as raw bytes, in that order. The deserializing
// constructor reads them back in the same order.
size_t MixQPlugin::getSerializationSize() const noexcept
{
    return sizeof(mm) + sizeof(mn) + sizeof(mk);
}
void MixQPlugin::serialize(void* buffer) const noexcept
{
    char* d = static_cast<char*>(buffer);
    [[maybe_unused]] char const* const start = d;
    writeArg(d, mm);
    writeArg(d, mn);
    writeArg(d, mk);
    // Keep this in lockstep with getSerializationSize().
    assert(static_cast<size_t>(d - start) == getSerializationSize());
}
// bool MixQPlugin::supportsFormatCombination(
// int pos, nvinfer1::PluginTensorDesc const* inOut,
// int nbInputs, int nbOutputs) noexcept
// {
// switch (pos)
// {
// case 0:
// // activation
// return inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format == TensorFormat::kLINEAR;
// case 1:
// // weights
// // Weights stored in checkpoint must have int8 type
// return inOut[pos].type == nvinfer1::DataType::kINT8 && inOut[pos].format == TensorFormat::kLINEAR;
// case 2:
// // scales channels
// return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
// case 3:
// // scales tokens
// return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
// case 4:
// // out
// return inOut[pos].type == nvinfer1::DataType::kHALF && inOut[pos].format == TensorFormat::kLINEAR;
// default:
// // Never should be here
// assert(false);
// return false;
// }
// }
void MixQPlugin::destroy() noexcept
{
    // Called when TensorRT no longer needs this instance. Instances are created with
    // `new` (clone, createPlugin, deserializePlugin), so the plugin deletes itself.
    delete this;
}
void MixQPlugin::setPluginNamespace(char const* libNamespace) noexcept
{
    // Assigning a null char pointer to std::string is undefined behavior; treat it as "".
    mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}
char const* MixQPlugin::getPluginNamespace() const noexcept
{
    return mNamespace.c_str();
}
///////////////
// Registers the int32 attributes this creator advertises via getFieldNames().
MixQPluginCreator::MixQPluginCreator()
{
    static char const* const kFieldNames[] = {"mm", "mn", "mk"};
    mPluginAttributes.clear();
    for (char const* fieldName : kFieldNames)
    {
        mPluginAttributes.emplace_back(PluginField(fieldName, nullptr, PluginFieldType::kINT32, -1));
    }
    // Publish the attribute list through the static collection.
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
// Same name/version strings MixQPlugin reports for itself.
char const* MixQPluginCreator::getPluginName() const noexcept
{
    char const* const pluginName = TRITON_FLASH_ATTENTION_PLUGIN_NAME;
    return pluginName;
}
char const* MixQPluginCreator::getPluginVersion() const noexcept
{
    char const* const pluginVersion = TRITON_FLASH_ATTENTION_PLUGIN_VERSION;
    return pluginVersion;
}
// Attribute metadata filled in by the creator's constructor.
PluginFieldCollection const* MixQPluginCreator::getFieldNames() noexcept
{
    PluginFieldCollection const* const fields = &mFC;
    return fields;
}
// Builds a MixQPlugin from the int32 attributes in `fc`. Each size is accepted under
// the name the constructor registers ("mm"/"mn"/"mk") and under its short form
// ("m"/"n"/"k"). Missing attributes, or ones without data, default to 0.
// Returns nullptr if construction throws.
IPluginV2* MixQPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    int m = 0;
    int n = 0;
    int k = 0;

    auto const isAttr = [](char const* attrName, char const* shortName, char const* longName) {
        return !strcmp(attrName, shortName) || !strcmp(attrName, longName);
    };
    auto const readInt32 = [](PluginField const& field, int& dst) {
        assert(field.type == PluginFieldType::kINT32);
        if (field.data != nullptr)
        {
            dst = *static_cast<int const*>(field.data);
        }
    };

    if (fc != nullptr && fc->fields != nullptr)
    {
        for (int i = 0; i < fc->nbFields; ++i)
        {
            PluginField const& field = fc->fields[i];
            char const* attrName = field.name;
            if (attrName == nullptr)
            {
                continue;
            }
            if (isAttr(attrName, "m", "mm"))
            {
                readInt32(field, m);
            }
            else if (isAttr(attrName, "n", "mn"))
            {
                readInt32(field, n);
            }
            else if (isAttr(attrName, "k", "mk"))
            {
                readInt32(field, k);
            }
        }
    }
    try
    {
        auto* obj = new MixQPlugin(m, n, k);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
    }
    return nullptr;
}
// Rebuilds a plugin from an engine's serialized bytes. TensorRT owns the returned
// object and releases it through MixQPlugin::destroy(). Returns nullptr if
// construction throws.
IPluginV2* MixQPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept
{
    MixQPlugin* plugin = nullptr;
    try
    {
        plugin = new MixQPlugin(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
    }
    catch (std::exception const& e)
    {
        std::cerr << "Caught exception: " << e.what() << std::endl;
        return nullptr;
    }
    return plugin;
}
void MixQPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{
    // Assigning a null char pointer to std::string is undefined behavior; treat it as "".
    mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}
char const* MixQPluginCreator::getPluginNamespace() const noexcept
{
    return mNamespace.c_str();
}
}