From f1404d823aba3e9362eb08d7279cc74a39cc7699 Mon Sep 17 00:00:00 2001 From: Leonid Verkhovtsev Date: Fri, 20 Dec 2024 10:20:58 +0300 Subject: [PATCH 1/3] replace 1d avg pool with torch mean Operations are equivalent in this context, because originally we use kernel size == -1 dim of the tensor. --- gigaam/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gigaam/model.py b/gigaam/model.py index bf58e79..5e60883 100644 --- a/gigaam/model.py +++ b/gigaam/model.py @@ -197,9 +197,9 @@ def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tenso Encoder-decoder forward to save model entirely in onnx format. """ encoded, _ = self.encoder(features, feature_lengths) - enc_pooled = nn.functional.avg_pool1d( - encoded, kernel_size=encoded.shape[-1].item() - ).squeeze(-1) + + enc_pooled = encoded.mean(dim=-1) + return nn.functional.softmax(self.head(enc_pooled)[0], dim=-1) def to_onnx(self, dir_path: str = ".") -> None: From ae91944f9a9c147cf7b789bdac255559b83bdf8f Mon Sep 17 00:00:00 2001 From: Leonid Verkhovtsev Date: Fri, 20 Dec 2024 11:13:07 +0300 Subject: [PATCH 2/3] add emo onnx inference refactor wrt to DRY part with loading and preprocessing wav file and infer through encoder. --- gigaam/onnx_utils.py | 64 +++++++++++++++++++++++++++++++------------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/gigaam/onnx_utils.py b/gigaam/onnx_utils.py index 53c5fc8..f4f1fdc 100644 --- a/gigaam/onnx_utils.py +++ b/gigaam/onnx_utils.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Optional +from typing import Dict, List, Optional import numpy as np import onnxruntime as rt @@ -59,27 +59,10 @@ def transcribe_sample( sessions: List[rt.InferenceSession], preprocessor: Optional[gigaam.preprocess.FeatureExtractor] = None, ) -> str: - if preprocessor is None: - preprocessor = gigaam.preprocess.FeatureExtractor(SAMPLE_RATE, FEAT_IN) assert model_type in ["ctc", "rnnt"], "Only `ctc` and `rnnt` inference supported" - input_signal = gigaam.load_audio(wav_file) - input_signal = preprocessor( - input_signal.unsqueeze(0), torch.tensor([input_signal.shape[-1]]) - )[0].numpy() - - enc_sess = sessions[0] - enc_inputs = { - node.name: data - for (node, data) in zip( - enc_sess.get_inputs(), - [input_signal.astype(DTYPE), [input_signal.shape[-1]]], - ) - } - enc_features = enc_sess.run( - [node.name for node in enc_sess.get_outputs()], enc_inputs - )[0] + enc_features = encode_wav(preprocessor, sessions, wav_file) token_ids = [] prev_token = BLANK_IDX @@ -131,6 +114,39 @@ def transcribe_sample( return "".join(VOCAB[tok] for tok in token_ids) +def encode_wav(preprocessor, sessions, wav_file): + if preprocessor is None: + preprocessor = gigaam.preprocess.FeatureExtractor(SAMPLE_RATE, FEAT_IN) + + input_signal = gigaam.load_audio(wav_file) + input_signal = preprocessor( + input_signal.unsqueeze(0), torch.tensor([input_signal.shape[-1]]) + )[0].numpy() + enc_sess = sessions[0] + enc_inputs = { + node.name: data + for (node, data) in zip( + enc_sess.get_inputs(), + [input_signal.astype(DTYPE), [input_signal.shape[-1]]], + ) + } + enc_features = enc_sess.run( + [node.name for node in enc_sess.get_outputs()], enc_inputs + )[0] + return enc_features + + +def recognise_emotion( + wav_file: str, + sessions: List[rt.InferenceSession], + preprocessor: Optional[gigaam.preprocess.FeatureExtractor] = None, +) -> Dict[str, float]: + id2name = ["angry", "sad", "neutral", "positive"] + probs = encode_wav(preprocessor, sessions, wav_file) + + return {emo: conf for emo, conf in zip(id2name, probs.tolist())} + + def load_onnx_sessions( onnx_dir: str, model_type: str, @@ -150,6 +166,16 @@ def load_onnx_sessions( model_path, providers=["CPUExecutionProvider"], sess_options=opts ) ] + elif model_type == "emo": + assert model_version == "v1", "There is only v1 version available." + model_path = f"{onnx_dir}/{model_version}_{model_type}.onnx" + + sessions = [ + rt.InferenceSession( + model_path, providers=["CPUExecutionProvider"], sess_options=opts + ) + ] + else: pth = f"{onnx_dir}/{model_version}_{model_type}" enc_sess = rt.InferenceSession( From a8af18320fafd8795528203be3431901bcba5b59 Mon Sep 17 00:00:00 2001 From: Leonid Verkhovtsev Date: Fri, 20 Dec 2024 12:55:40 +0300 Subject: [PATCH 3/3] add conversion and inference example of EMO model. --- inference_example.ipynb | 960 ++++++++++++++++++++++------------------ 1 file changed, 528 insertions(+), 432 deletions(-) diff --git a/inference_example.ipynb b/inference_example.ipynb index 22a7464..cca988d 100644 --- a/inference_example.ipynb +++ b/inference_example.ipynb @@ -1,468 +1,564 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Me66GFwT0ABG" - }, - "source": [ - "### Installing reqs and downloading examples" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PlM9MI70iTIp" - }, - "outputs": [], - "source": [ - "# If package is not installed\n", - "! pip install git+https://github.com/salute-developers/GigaAM.git" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GSE4HSfr1P0B" - }, - "outputs": [], - "source": [ - "# Downloading wavs for examples\n", - "!wget https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav\n", - "!wget https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tIvec0280O64" - }, - "source": [ - "## Speech Recognition" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xCgbSPkViZpF" - }, - "outputs": [], - "source": [ - "import os\n", - "import warnings\n", - "from typing import Dict\n", - "\n", - "import gigaam\n", - "\n", - "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", - "warnings.simplefilter(action=\"ignore\", category=UserWarning)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 71 - }, - "id": "R2aUkZbG8fJ6", - "outputId": "3172d43c-19cf-4b97-df5e-108e40d92144" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████| 888M/888M [02:27<00:00, 6.29MiB/s]\n" - ] - }, - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = gigaam.load_model(\n", - " \"ctc\", # GigaAM-V2 CTC model\n", - " fp16_encoder=True, # to use fp16 encoder weights - GPU only\n", - " use_flash=False, # disable flash attention - colab does not support it\n", - ")\n", - "model.transcribe(\"example.wav\")" - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Me66GFwT0ABG" + }, + "source": [ + "### Installing reqs and downloading examples" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "PlM9MI70iTIp", + "ExecuteTime": { + "end_time": "2024-12-20T09:53:37.253270Z", + "start_time": "2024-12-20T09:53:37.250253Z" + } + }, + "source": [ + "# If package is not installed\n", + "! pip install git+https://github.com/salute-developers/GigaAM.git" + ], + "outputs": [], + "execution_count": 1 + }, + { + "cell_type": "code", + "metadata": { + "id": "GSE4HSfr1P0B", + "ExecuteTime": { + "end_time": "2024-12-20T09:53:42.915217Z", + "start_time": "2024-12-20T09:53:42.913196Z" + } + }, + "source": [ + "# Downloading wavs for examples\n", + "!wget https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/example.wav\n", + "!wget https://n-ws-q0bez.s3pd12.sbercloud.ru/b-ws-q0bez-jpv/GigaAM/long_example.wav" + ], + "outputs": [], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tIvec0280O64" + }, + "source": [ + "## Speech Recognition" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xCgbSPkViZpF", + "ExecuteTime": { + "end_time": "2024-12-20T09:53:50.185838Z", + "start_time": "2024-12-20T09:53:48.798509Z" + } + }, + "source": [ + "import os\n", + "import warnings\n", + "from typing import Dict\n", + "\n", + "import gigaam\n", + "\n", + "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", + "warnings.simplefilter(action=\"ignore\", category=UserWarning)" + ], + "outputs": [], + "execution_count": 3 + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 71 }, + "id": "R2aUkZbG8fJ6", + "outputId": "3172d43c-19cf-4b97-df5e-108e40d92144", + "ExecuteTime": { + "end_time": "2024-12-20T09:53:53.351240Z", + "start_time": "2024-12-20T09:53:51.708228Z" + } + }, + "source": [ + "model = gigaam.load_model(\n", + " \"ctc\", # GigaAM-V2 CTC model\n", + " fp16_encoder=True, # to use fp16 encoder weights - GPU only\n", + " use_flash=False, # disable flash attention - colab does not support it\n", + ")\n", + "model.transcribe(\"example.wav\")" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 107 - }, - "id": "5nPc8flc1U3d", - "outputId": "cce9f875-62ad-4b48-a6a9-29701eb343ed" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████| 892M/892M [02:34<00:00, 6.07MiB/s]\n", - "WARNING:root:flash_attn is not supported on CPU. Disabling it...\n", - "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" - ] - }, - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = gigaam.load_model(\n", - " \"rnnt\", # GigaAM-V2 RNNT model\n", - " device=\"cpu\", # CPU-inference\n", - ")\n", - "model.transcribe(\"example.wav\")" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "78WAYQTa14Qs" - }, - "source": [ - "### Long-Form Speech Recognition" + "data": { + "text/plain": [ + "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 4 + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 107 }, + "id": "5nPc8flc1U3d", + "outputId": "cce9f875-62ad-4b48-a6a9-29701eb343ed", + "ExecuteTime": { + "end_time": "2024-12-20T09:53:58.005264Z", + "start_time": "2024-12-20T09:53:56.178697Z" + } + }, + "source": [ + "model = gigaam.load_model(\n", + " \"rnnt\", # GigaAM-V2 RNNT model\n", + " device=\"cpu\", # CPU-inference\n", + ")\n", + "model.transcribe(\"example.wav\")" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hvQdrpqG39kQ" - }, - "outputs": [], - "source": [ - "!pip install gigaam[longform]" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "riG4tjcH8fJ7" - }, - "source": [ - "* For long-form inference:\n", - " * generate [Hugging Face API token](https://huggingface.co/docs/hub/security-tokens)\n", - " * accept the conditions to access [pyannote/voice-activity-detection](https://huggingface.co/pyannote/voice-activity-detection) files and content\n", - " * accept the conditions to access [pyannote/segmentation](https://huggingface.co/pyannote/segmentation) files and content" + "data": { + "text/plain": [ + "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "78WAYQTa14Qs" + }, + "source": [ + "### Long-Form Speech Recognition" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hvQdrpqG39kQ" + }, + "outputs": [], + "source": [ + "!pip install gigaam[longform]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "riG4tjcH8fJ7" + }, + "source": [ + "* For long-form inference:\n", + " * generate [Hugging Face API token](https://huggingface.co/docs/hub/security-tokens)\n", + " * accept the conditions to access [pyannote/voice-activity-detection](https://huggingface.co/pyannote/voice-activity-detection) files and content\n", + " * accept the conditions to access [pyannote/segmentation](https://huggingface.co/pyannote/segmentation) files and content" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GlpX1XOX4vGw" + }, + "outputs": [], + "source": [ + "os.environ[\"HF_TOKEN\"] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "DI_tb_N918FS", + "outputId": "db002337-fdd4-4fba-95e8-802e010d5303" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GlpX1XOX4vGw" - }, - "outputs": [], - "source": [ - "os.environ[\"HF_TOKEN\"] = \"\"" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "[00:00:00 - 00:16:83]: вечерня отошла давно но в кельях тихо и темно уже и сам игумен строгий свои молитвы прекратил и кости ветхие склонил перекрестясь на одр убогий кругом и сон и тишина но церкви дверь отворена\n", + "[00:17:10 - 00:32:61]: трепещет луч лампады и тускло озаряет он и темную живопись икон и позлощенные оклады и раздается в тишине то тяжкий вздох то шепот важный и мрачно дремлет в вышине старинный свод\n", + "[00:32:95 - 00:49:33]: глухой и влажный стоят за клиросом чернец и грешник неподвижны оба и шепот их как глаз из гроба и грешник бледен как мертвец монах несчастный полно перестань\n", + "[00:49:82 - 01:05:74]: ужасна исповедь злодея заплачена тобою дань тому кто в злобе пламенея лукаво грешника блюдет и к вечной гибели ведет смирись опомнись время время раскаянье покров\n", + "[01:05:97 - 01:10:90]: я разрешу тебя грехов сложи мучительное бремя\n" + ] + } + ], + "source": [ + "model = gigaam.load_model(\"ctc\", use_flash=False)\n", + "recognition_result = model.transcribe_longform(\"long_example.wav\")\n", + "\n", + "for utterance in recognition_result:\n", + " transcription = utterance[\"transcription\"]\n", + " start, end = utterance[\"boundaries\"]\n", + " print(f\"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ywEjYaAe3BMU" + }, + "source": [ + "## Emotion recognition" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "HWsJ2RuG3HNB", + "outputId": "6d1f8abb-9fb0-4c28-c5e3-c21413a22796", + "ExecuteTime": { + "end_time": "2024-12-20T09:54:11.149499Z", + "start_time": "2024-12-20T09:54:09.062327Z" + } + }, + "source": [ + "model = gigaam.load_model(\"emo\")\n", + "emotion2prob: Dict[str, int] = model.get_probs(\"example.wav\")\n", + "\n", + "print(\", \".join([f\"{emotion}: {prob:.3f}\" for emotion, prob in emotion2prob.items()]))" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "DI_tb_N918FS", - "outputId": "db002337-fdd4-4fba-95e8-802e010d5303" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[00:00:00 - 00:16:83]: вечерня отошла давно но в кельях тихо и темно уже и сам игумен строгий свои молитвы прекратил и кости ветхие склонил перекрестясь на одр убогий кругом и сон и тишина но церкви дверь отворена\n", - "[00:17:10 - 00:32:61]: трепещет луч лампады и тускло озаряет он и темную живопись икон и позлощенные оклады и раздается в тишине то тяжкий вздох то шепот важный и мрачно дремлет в вышине старинный свод\n", - "[00:32:95 - 00:49:33]: глухой и влажный стоят за клиросом чернец и грешник неподвижны оба и шепот их как глаз из гроба и грешник бледен как мертвец монах несчастный полно перестань\n", - "[00:49:82 - 01:05:74]: ужасна исповедь злодея заплачена тобою дань тому кто в злобе пламенея лукаво грешника блюдет и к вечной гибели ведет смирись опомнись время время раскаянье покров\n", - "[01:05:97 - 01:10:90]: я разрешу тебя грехов сложи мучительное бремя\n" - ] - } - ], - "source": [ - "model = gigaam.load_model(\"ctc\", use_flash=False)\n", - "recognition_result = model.transcribe_longform(\"long_example.wav\")\n", - "\n", - "for utterance in recognition_result:\n", - " transcription = utterance[\"transcription\"]\n", - " start, end = utterance[\"boundaries\"]\n", - " print(f\"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}\")" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "ywEjYaAe3BMU" - }, - "source": [ - "## Emotion recognition" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "angry: 0.000, sad: 0.002, neutral: 0.923, positive: 0.074\n" + ] + } + ], + "execution_count": 6 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1EBcRVcC3P2E" + }, + "source": [ + "## GigaAM embeddings" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "5y2mEAGU3TYN", + "outputId": "ae488719-df9f-4318-ecf7-56c95a87a4ac", + "ExecuteTime": { + "end_time": "2024-12-20T09:54:13.888322Z", + "start_time": "2024-12-20T09:54:12.086996Z" + } + }, + "source": [ + "# audio-only pretrained encoder\n", + "model = gigaam.load_model(\"ssl\", use_flash=False)\n", + "\n", + "emb, _ = model.embed_audio(\"example.wav\")\n", + "print(emb)" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "HWsJ2RuG3HNB", - "outputId": "6d1f8abb-9fb0-4c28-c5e3-c21413a22796" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████| 924M/924M [02:04<00:00, 7.78MiB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "angry: 0.000, sad: 0.002, neutral: 0.923, positive: 0.074\n" - ] - } - ], - "source": [ - "model = gigaam.load_model(\"emo\")\n", - "emotion2prob: Dict[str, int] = model.get_probs(\"example.wav\")\n", - "\n", - "print(\", \".join([f\"{emotion}: {prob:.3f}\" for emotion, prob in emotion2prob.items()]))" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" + ] }, { - "cell_type": "markdown", - "metadata": { - "id": "1EBcRVcC3P2E" - }, - "source": [ - "## GigaAM embeddings" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[-0.2815, 0.3648, 0.4501, ..., -0.4732, -0.4029, -0.2401],\n", + " [ 0.1610, -0.4995, -0.0576, ..., -0.6241, -0.2316, -0.2052],\n", + " [-1.1856, -1.0031, -0.6082, ..., -0.5144, -0.3742, -0.2669],\n", + " ...,\n", + " [ 0.0188, -0.3748, -0.8955, ..., 0.1732, 0.0576, 0.1311],\n", + " [ 0.2698, -0.0659, -0.5008, ..., -1.4415, -1.4807, -1.4491],\n", + " [-1.5653, -1.6697, -1.2832, ..., 0.5118, 0.4837, 0.0140]]],\n", + " grad_fn=)\n" + ] + } + ], + "execution_count": 7 + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "hW8GKJrl3tm7", + "outputId": "160f4e59-3020-4986-b35e-5c7fa6e105fd", + "ExecuteTime": { + "end_time": "2024-12-20T09:54:18.427764Z", + "start_time": "2024-12-20T09:54:16.707066Z" + } + }, + "source": [ + "# you also can embed audio with CTC- or RNNT-finetuned encoder\n", + "model = gigaam.load_model(\"ctc\", use_flash=False)\n", + "\n", + "emb, _ = model.embed_audio(\"example.wav\")\n", + "print(emb)" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "5y2mEAGU3TYN", - "outputId": "ae488719-df9f-4318-ecf7-56c95a87a4ac" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████| 887M/887M [03:21<00:00, 4.62MiB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-0.2829, 0.3638, 0.4520, ..., -0.4743, -0.4033, -0.2417],\n", - " [ 0.1611, -0.5006, -0.0584, ..., -0.6239, -0.2320, -0.2054],\n", - " [-1.1849, -1.0029, -0.6111, ..., -0.5137, -0.3737, -0.2654],\n", - " ...,\n", - " [ 0.0181, -0.3763, -0.8959, ..., 0.1716, 0.0556, 0.1298],\n", - " [ 0.2690, -0.0654, -0.5020, ..., -1.4432, -1.4827, -1.4490],\n", - " [-1.5650, -1.6693, -1.2834, ..., 0.5117, 0.4839, 0.0136]]],\n", - " device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "# audio-only pretrained encoder\n", - "model = gigaam.load_model(\"ssl\", use_flash=False)\n", - "\n", - "emb, _ = model.embed_audio(\"example.wav\")\n", - "print(emb)" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:fp16 is not supported on CPU. Leaving fp32 weights...\n" + ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "hW8GKJrl3tm7", - "outputId": "160f4e59-3020-4986-b35e-5c7fa6e105fd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([[[-1.0334, -0.2841, -0.3606, ..., -0.2859, -0.6947, -0.7006],\n", - " [-0.4317, -0.0140, -0.9296, ..., 0.1781, 0.2170, -0.0181],\n", - " [-0.9221, -1.1284, -0.6389, ..., -1.0664, -1.3304, -1.2421],\n", - " ...,\n", - " [ 0.5749, 0.5176, -0.0996, ..., 1.7497, 1.8691, 2.1302],\n", - " [-0.2919, -0.8087, -1.2554, ..., -0.7942, -0.7634, -0.7938],\n", - " [-1.8086, -2.1976, -2.4012, ..., 0.8310, 1.0165, 1.0165]]],\n", - " device='cuda:0', grad_fn=)\n" - ] - } - ], - "source": [ - "# you also can embed audio with CTC- or RNNT-finetuned encoder\n", - "model = gigaam.load_model(\"ctc\", use_flash=False)\n", - "\n", - "emb, _ = model.embed_audio(\"example.wav\")\n", - "print(emb)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[[-1.2887, -0.4310, -0.5055, ..., -0.2268, -0.6392, -0.7267],\n", + " [-0.3969, 0.0309, -0.9771, ..., 0.2250, 0.2793, 0.0538],\n", + " [-0.8635, -1.1605, -0.9831, ..., -1.1848, -1.4710, -1.4190],\n", + " ...,\n", + " [ 0.7540, 0.5017, -0.1576, ..., 1.6958, 1.8045, 2.0394],\n", + " [-0.0138, -0.7460, -1.1923, ..., -0.8860, -0.8260, -0.8353],\n", + " [-1.6170, -1.9980, -2.2678, ..., 0.9713, 1.1955, 1.2046]]],\n", + " grad_fn=)\n" + ] + } + ], + "execution_count": 8 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "je4WAyLgz0Ua" + }, + "source": [ + "## Export to ONNX" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "7tvo_v_iz0Ua", + "outputId": "4db517e6-7e85-4c01-c4c3-0b21744b8988", + "ExecuteTime": { + "end_time": "2024-12-20T09:54:27.252464Z", + "start_time": "2024-12-20T09:54:22.040258Z" + } + }, + "source": [ + "onnx_dir = \"onnx\"\n", + "model_type = \"rnnt\" # or \"ctc\"\n", + "\n", + "model = gigaam.load_model(\n", + " model_type,\n", + " fp16_encoder=False, # only fp32 tensors\n", + " use_flash=False, # disable flash attention\n", + ")\n", + "model.to_onnx(dir_path=onnx_dir)" + ], + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "je4WAyLgz0Ua" - }, - "source": [ - "## Export to ONNX" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Succesfully ported onnx v2_rnnt_encoder to onnx/v2_rnnt_encoder.onnx.\n", + "Succesfully ported onnx v2_rnnt_decoder to onnx/v2_rnnt_decoder.onnx.\n", + "Succesfully ported onnx v2_rnnt_joint to onnx/v2_rnnt_joint.onnx.\n" + ] + } + ], + "execution_count": 9 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-20T09:54:37.783741Z", + "start_time": "2024-12-20T09:54:32.381577Z" + } + }, + "cell_type": "code", + "source": [ + "# Export emo model\n", + "onnx_dir = \"onnx\"\n", + "model_type = \"emo\"\n", + "\n", + "model = gigaam.load_model(\n", + " model_type,\n", + " fp16_encoder=False, # only fp32 tensors\n", + " use_flash=False, # disable flash attention\n", + " )\n", + "model.to_onnx(dir_path=onnx_dir)" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7tvo_v_iz0Ua", - "outputId": "4db517e6-7e85-4c01-c4c3-0b21744b8988" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|███████████████████████████████████████| 892M/892M [03:20<00:00, 4.66MiB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Succesfully ported onnx v2_rnnt_encoder to onnx/v2_rnnt_encoder.onnx.\n", - "Succesfully ported onnx v2_rnnt_decoder to onnx/v2_rnnt_decoder.onnx.\n", - "Succesfully ported onnx v2_rnnt_joint to onnx/v2_rnnt_joint.onnx.\n" - ] - } - ], - "source": [ - "onnx_dir = \"onnx\"\n", - "model_type = \"rnnt\" # or \"ctc\"\n", - "\n", - "model = gigaam.load_model(\n", - " model_type,\n", - " fp16_encoder=False, # only fp32 tensors\n", - " use_flash=False, # disable flash attention\n", - ")\n", - "model.to_onnx(dir_path=onnx_dir)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Succesfully ported onnx v1_emo to onnx/v1_emo.onnx.\n" + ] + } + ], + "execution_count": 10 + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NL3W76cgz0Ua" + }, + "source": [ + "### ONNX inference" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 }, + "id": "j_rjvUeJz0Ua", + "outputId": "2f8a9e71-c5aa-4fb6-89cf-d251eeb7dce3", + "ExecuteTime": { + "end_time": "2024-12-20T09:54:44.869836Z", + "start_time": "2024-12-20T09:54:39.570067Z" + } + }, + "source": [ + "from gigaam.onnx_utils import load_onnx_sessions, transcribe_sample\n", + "\n", + "model_type = \"rnnt\"\n", + "\n", + "sessions = load_onnx_sessions(onnx_dir, model_type)\n", + "transcribe_sample(\"example.wav\", model_type, sessions)" + ], + "outputs": [ { - "cell_type": "markdown", - "metadata": { - "id": "NL3W76cgz0Ua" - }, - "source": [ - "### ONNX inference" + "data": { + "text/plain": [ + "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" ] - }, + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 11 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-20T09:54:48.977994Z", + "start_time": "2024-12-20T09:54:46.532454Z" + } + }, + "cell_type": "code", + "source": [ + "# Emo model onnx inference\n", + "from gigaam.onnx_utils import load_onnx_sessions, recognise_emotion\n", + "\n", + "model_type = \"emo\"\n", + "sessions = load_onnx_sessions(onnx_dir, model_type, model_version=\"v1\")\n", + "recognise_emotion(\"example.wav\", sessions)" + ], + "outputs": [ { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 53 - }, - "id": "j_rjvUeJz0Ua", - "outputId": "2f8a9e71-c5aa-4fb6-89cf-d251eeb7dce3" - }, - "outputs": [ - { - "data": { - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - }, - "text/plain": [ - "'ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый'" - ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from gigaam.onnx_utils import load_onnx_sessions, transcribe_sample\n", - "\n", - "sessions = load_onnx_sessions(onnx_dir, model_type)\n", - "transcribe_sample(\"example.wav\", model_type, sessions)" + "data": { + "text/plain": [ + "{'angry': 7.712716615060344e-05,\n", + " 'sad': 0.0021968623623251915,\n", + " 'neutral': 0.9233159422874451,\n", + " 'positive': 0.07441005110740662}" ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.18" - } + ], + "execution_count": 12 + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "name": "python3", + "language": "python" }, - "nbformat": 4, - "nbformat_minor": 0 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 0 }