Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ authors = ["Oddity.ai Developers <hello@oddity.ai>"]
repository = "https://github.com/oddity-ai/async-tensorrt"
license = "MIT OR Apache-2.0"

[features]
lean = []

[dependencies]
async-cuda = "0.5.4"
cpp = "0.5"
Expand Down
5 changes: 5 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ fn main() {
#[cfg(not(windows))]
println!("cargo:rustc-link-search=/usr/local/tensorrt/lib64");

#[cfg(feature = "lean")]
println!("cargo:rustc-link-lib=nvinfer_lean");

#[cfg(not(feature = "lean"))]
println!("cargo:rustc-link-lib=nvinfer");
#[cfg(not(feature = "lean"))]
println!("cargo:rustc-link-lib=nvonnxparser");
}
13 changes: 13 additions & 0 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::ffi::memory::HostBuffer;
use crate::ffi::sync::engine::Engine as InnerEngine;
use crate::ffi::sync::engine::ExecutionContext as InnerExecutionContext;

pub use crate::ffi::sync::engine::TensorDataType;
pub use crate::ffi::sync::engine::TensorIoMode;

type Result<T> = std::result::Result<T, crate::error::Error>;
Expand Down Expand Up @@ -77,6 +78,18 @@ impl Engine {
/// Get the IO mode of a tensor (whether it is an input or output tensor).
///
/// Delegates to the synchronous FFI engine wrapper.
///
/// # Arguments
///
/// * `tensor_name` - Tensor name.
pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
    self.inner.tensor_io_mode(tensor_name)
}

/// Get the data type of a tensor.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a569361fe7b7fced4b9c3f500346baca2)
///
/// # Arguments
///
/// * `tensor_name` - Tensor name.
///
/// # Panics
///
/// Panics if the value reported by TensorRT does not map to a known
/// [`TensorDataType`] variant (see `TensorDataType::from_i32`).
#[inline(always)]
pub fn tensor_data_type(&self, tensor_name: &str) -> TensorDataType {
    self.inner.tensor_data_type(tensor_name)
}
}

/// Context for executing inference using an engine.
Expand Down
42 changes: 42 additions & 0 deletions src/ffi/builder_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,34 @@ impl BuilderConfig {
self
}

/// Set the `kVERSION_COMPATIBLE` flag.
///
/// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e)
/// [TensorRT documentation for `kVERSION_COMPATIBLE`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a64917aa1f8d9238c555a46fa1d4e83b7)
pub fn with_version_compatibility(mut self) -> Self {
    let internal = self.as_mut_ptr();
    cpp!(unsafe [
        internal as "void*"
    ] {
        ((IBuilderConfig*) internal)->setFlag(BuilderFlag::kVERSION_COMPATIBLE);
    });
    self
}

/// Misspelled alias of [`BuilderConfig::with_version_compatibility`], kept for
/// backward compatibility with existing callers.
#[deprecated(note = "use `with_version_compatibility` instead")]
pub fn with_version_compability(self) -> Self {
    self.with_version_compatibility()
}

/// Set the `kEXCLUDE_LEAN_RUNTIME` flag on this builder configuration.
///
/// Consumes and returns `self` so calls can be chained builder-style.
///
/// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e)
/// [TensorRT documentation for `kEXCLUDE_LEAN_RUNTIME`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a239d59ead8393beeecaadd21ce3b3502)
pub fn with_exclude_lean_runtime(mut self) -> Self {
    let config_ptr = self.as_mut_ptr();
    cpp!(unsafe [
        config_ptr as "void*"
    ] {
        ((IBuilderConfig*) config_ptr)->setFlag(BuilderFlag::kEXCLUDE_LEAN_RUNTIME);
    });
    self
}

/// Set the `kFP16` flag.
///
/// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e)
Expand All @@ -80,6 +108,20 @@ impl BuilderConfig {
self
}

/// Set the `kINT8` flag on this builder configuration.
///
/// Consumes and returns `self` so calls can be chained builder-style.
///
/// [TensorRT documentation for `setFlag`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ac9821504ae7a11769e48b0e62761837e)
/// [TensorRT documentation for `kINT8`](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/namespacenvinfer1.html#abdc74c40fe7a0c3d05d2caeccfbc29c1a69c1a4a69db0e50820cf63122f90ad09)
pub fn with_int8(mut self) -> Self {
    let builder_config = self.as_mut_ptr();
    cpp!(unsafe [
        builder_config as "void*"
    ] {
        ((IBuilderConfig*) builder_config)->setFlag(BuilderFlag::kINT8);
    });
    self
}

/// Add an optimization profile.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder_config.html#ab97fa40c85fa8afab65fc2659e38da82)
Expand Down
5 changes: 5 additions & 0 deletions src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@ mod pre {
mod logger;
}

#[cfg(not(feature = "lean"))]
pub mod builder_config;

pub mod error;
pub mod memory;
pub mod network;

#[cfg(not(feature = "lean"))]
pub mod optimization_profile;

pub mod parser;
pub mod sync;

Expand Down
51 changes: 51 additions & 0 deletions src/ffi/sync/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ impl Engine {
TensorIoMode::from_i32(tensor_io_mode)
}

/// Get the data type of a tensor by name.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#a569361fe7b7fced4b9c3f500346baca2)
///
/// # Arguments
///
/// * `tensor_name` - Tensor name.
///
/// # Panics
///
/// Panics if `tensor_name` contains an interior NUL byte (the `CString`
/// conversion below), or if TensorRT reports a data type value unknown to
/// [`TensorDataType::from_i32`].
pub fn tensor_data_type(&self, tensor_name: &str) -> TensorDataType {
    let internal = self.as_ptr();
    // TensorRT expects a NUL-terminated C string; keep the CString alive
    // for the duration of the FFI call so the pointer stays valid.
    let tensor_name_cstr = std::ffi::CString::new(tensor_name).unwrap();
    let tensor_name_ptr = tensor_name_cstr.as_ptr();
    let tensor_data_type = cpp!(unsafe [
        internal as "const void*",
        tensor_name_ptr as "const char*"
    ] -> i32 as "std::int32_t" {
        return (std::int32_t) ((const ICudaEngine*) internal)->getTensorDataType(tensor_name_ptr);
    });
    // Map the raw enum value back to the Rust-side enum (panics on unknown values).
    TensorDataType::from_i32(tensor_data_type)
}

#[inline(always)]
pub fn as_ptr(&self) -> *const std::ffi::c_void {
let Engine { internal, .. } = *self;
Expand Down Expand Up @@ -342,3 +355,41 @@ struct Dims {
pub nbDims: i32,
pub d: [i32; 8usize],
}

/// Data type of a tensor.
///
/// Converted from the raw integer value that TensorRT's `getTensorDataType`
/// returns (see `Engine::tensor_data_type`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum TensorDataType {
    /// TensorRT data type value 0.
    FLOAT,
    /// TensorRT data type value 1.
    HALF,
    /// TensorRT data type value 2.
    INT8,
    /// TensorRT data type value 3.
    INT32,
    /// TensorRT data type value 4.
    BOOL,
    /// TensorRT data type value 5.
    UINT8,
    /// TensorRT data type value 6.
    FP8,
    /// TensorRT data type value 7.
    BF16,
    /// TensorRT data type value 8.
    INT64,
    /// TensorRT data type value 9.
    INT4,
}

impl TensorDataType {
    /// Create [`TensorDataType`] from `value`.
    ///
    /// # Arguments
    ///
    /// * `value` - Integer representation of the data type, as reported by TensorRT.
    ///
    /// # Panics
    ///
    /// Panics if `value` does not correspond to a known data type.
    fn from_i32(value: i32) -> Self {
        match value {
            0 => TensorDataType::FLOAT,
            1 => TensorDataType::HALF,
            2 => TensorDataType::INT8,
            3 => TensorDataType::INT32,
            4 => TensorDataType::BOOL,
            5 => TensorDataType::UINT8,
            6 => TensorDataType::FP8,
            7 => TensorDataType::BF16,
            8 => TensorDataType::INT64,
            9 => TensorDataType::INT4,
            _ => panic!("Unknown data type {}", value),
        }
    }
}
2 changes: 2 additions & 0 deletions src/ffi/sync/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(not(feature = "lean"))]
pub mod builder;

pub mod engine;
pub mod runtime;
10 changes: 10 additions & 0 deletions src/ffi/sync/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ impl Runtime {
result!(internal_engine, Engine::wrap(internal_engine, self))
}

/// Set whether the runtime is allowed to deserialize engines with host executable code.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#a5a19c2524f74179cd9b781c6240eb3ce)
///
/// # Arguments
///
/// * `allowed` - Whether the runtime is allowed to deserialize engines with host executable code.
pub fn set_engine_host_code_allowed(&mut self, allowed: bool) {
    let internal = self.as_mut_ptr();
    cpp!(unsafe [
        internal as "void*",
        allowed as "bool"
    ] {
        ((IRuntime*) internal)->setEngineHostCodeAllowed(allowed);
    });
}

#[inline(always)]
pub fn as_ptr(&self) -> *const std::ffi::c_void {
self.addr
Expand Down
10 changes: 10 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#![recursion_limit = "256"]

#[cfg(not(feature = "lean"))]
pub mod builder;

pub mod engine;
pub mod error;
pub mod ffi;
Expand All @@ -9,12 +11,20 @@ pub mod runtime;
#[cfg(test)]
mod tests;

#[cfg(not(feature = "lean"))]
pub use builder::Builder;

pub use engine::{Engine, ExecutionContext};
pub use error::Error;

#[cfg(not(feature = "lean"))]
pub use ffi::builder_config::BuilderConfig;

pub use ffi::memory::HostBuffer;
pub use ffi::network::{NetworkDefinition, NetworkDefinitionCreationFlags, Tensor};

#[cfg(not(feature = "lean"))]
pub use ffi::optimization_profile::OptimizationProfile;

pub use ffi::parser::Parser;
pub use runtime::Runtime;
11 changes: 11 additions & 0 deletions src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ impl Runtime {
Self { inner }
}

/// Set whether the runtime is allowed to deserialize engines with host executable code.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#a5a19c2524f74179cd9b781c6240eb3ce)
///
/// # Arguments
///
/// * `allowed` - Whether the runtime is allowed to deserialize engines with host executable code.
///
/// NOTE(review): presumably this must be called *before* deserializing an
/// engine that carries host executable code — confirm against the TensorRT
/// documentation linked above.
pub fn set_engine_host_code_allowed(&mut self, allowed: bool) {
    // Delegate to the synchronous FFI runtime wrapper.
    self.inner.set_engine_host_code_allowed(allowed);
}

/// Deserialize engine from a plan (a [`HostBuffer`]).
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_runtime.html#ad0dc765e77cab99bfad901e47216a767)
Expand Down