Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/layers/extensions/inference/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <torch/types.h>

// T maybe vector type, and may be different from t.dtype
template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion src/layers/extensions/inference/def.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <torch/extension.h>
#include <torch/types.h>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
process_with_mask_cuda(const torch::Tensor& y, const torch::Tensor& scales, const torch::Tensor& means,
Expand Down
1 change: 1 addition & 0 deletions src/layers/extensions/inference/impl.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#include <torch/nn/functional.h>
#include "def.h"
namespace F = torch::nn::functional;

Expand Down
18 changes: 10 additions & 8 deletions src/layers/extensions/inference/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/ops/empty_like.h>
#include <torch/types.h>

#include "common.h"
#include "def.h"
Expand Down Expand Up @@ -121,10 +123,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
process_with_mask_cuda(const torch::Tensor& y, const torch::Tensor& scales, const torch::Tensor& means,
const torch::Tensor& mask, const float force_zero_thres)
{
auto y_res = torch::empty_like(y);
auto y_q = torch::empty_like(y);
auto y_hat = torch::empty_like(y);
auto s_hat = torch::empty_like(y);
auto y_res = at::empty_like(y);
auto y_q = at::empty_like(y);
auto y_hat = at::empty_like(y);
auto s_hat = at::empty_like(y);

if (y.dtype() == torch::kFloat32) {
process_with_mask_dispatcher<float, float4>(y_res, y_q, y_hat, s_hat, y, scales, means,
Expand Down Expand Up @@ -855,7 +857,7 @@ __forceinline__ void round_and_to_int8_dispatcher(torch::Tensor& z, torch::Tenso

torch::Tensor round_and_to_int8_cuda(torch::Tensor& z)
{
auto z_int8 = torch::empty_like(z, at::TensorOptions().dtype(torch::kInt8));
auto z_int8 = at::empty_like(z, at::TensorOptions().dtype(torch::kInt8));
if (z.dtype() == torch::kFloat32) {
round_and_to_int8_dispatcher<float, float4>(z, z_int8);
} else if (z.dtype() == torch::kFloat16) {
Expand Down Expand Up @@ -900,7 +902,7 @@ __forceinline__ void clamp_reciprocal_with_quant_dispatcher(torch::Tensor& q_dec
torch::Tensor clamp_reciprocal_with_quant_cuda(const torch::Tensor& q_dec, torch::Tensor& y,
const float min_val)
{
auto q_dec_clamp = torch::empty_like(q_dec);
auto q_dec_clamp = at::empty_like(q_dec);
if (q_dec.dtype() == torch::kFloat32) {
clamp_reciprocal_with_quant_dispatcher<float, float4>(q_dec_clamp, q_dec, y, min_val);
} else if (q_dec.dtype() == torch::kFloat16) {
Expand Down Expand Up @@ -1123,7 +1125,7 @@ torch::Tensor bias_wsilu_depthwise_conv2d_cuda(const torch::Tensor& x, const tor
const int H = x_shape[2];
const int W = x_shape[3];

auto out = torch::empty_like(x);
auto out = at::empty_like(x);

const int BLOCK_SIZE = 32;
const int THREAD_NUM_X = 16;
Expand Down