import io
import wave
from dataclasses import dataclass
from enum import Enum
from subprocess import PIPE, Popen
from typing import Any, Dict, Iterator, List, Tuple, Union

import requests
import webrtcvad

from label_studio.ml import LabelStudioMLBase


# Strip video, downmix to mono, resample to 16 kHz, and encode as 16-bit
# little-endian PCM WAV -- the format expected by read_wave() and webrtcvad.
_ffmpeg_settings = ['-vn', '-ac', '1', '-ar', '16000', '-f', 'wav', '-acodec', 'pcm_s16le']


class LabelStudioPipeline(LabelStudioMLBase):
    """Model loader."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        assert len(self.parsed_label_config) == 1
        self.from_name, self.info = list(self.parsed_label_config.items())[0]
        assert self.info['type'] == 'Labels'

        # the model has exactly one input: the URI of the audio file
        assert len(self.info['to_name']) == 1
        assert len(self.info['inputs']) == 1
        self.to_name = self.info['to_name'][0]
        self.value = self.info['inputs'][0]['value']

        self._model = WebRtcExtractor(mode=2)

    def predict(self, tasks: List, **kwargs) -> List:
        """This is where inference happens: model returns the list of predictions based on input list of tasks."""
        results = list()
        for task in tasks:
            uri = task["data"][self.value]

            with self._download(uri) as file_bytes, self._transcode(file_bytes) as transcode_file:
                voice_activity_detection: List[ExtractionItem] = self._model(transcode_file)

            segments_result = list()
            for segment in voice_activity_detection:
                if segment.label == Label.no_activity and segment.duration >= 1000:
                    segments_result.append(
                        {
                            "from_name": self.from_name,
                            "to_name": self.to_name,
                            "type": "labels",
                            "value": {"start": round(segment.start / 1000, 2),
                                      "end": round(segment.end / 1000, 2),
                                      "labels": ["Other"]},
                        }
                    )

            results.append({"result": segments_result})
        return results

    def fit(self, completions: List, **kwargs) -> Dict[str, Any]:
        """Train the model on a list of completions and return a dict of created links and resources.

        Training is not implemented for this backend: predictions come from the
        pretrained WebRTC VAD, so an empty dict is returned.
        """
        return {}

    @staticmethod
    def _transcode(file: io.BytesIO) -> io.BytesIO:
        """Transcode an in-memory media file to 16 kHz mono 16-bit PCM WAV.

        :param file: raw file bytes
        :return: transcoded WAV bytes
        """
        ffmpeg_command = ['ffmpeg', '-y', '-hide_banner', '-i', 'pipe:0', *_ffmpeg_settings, 'pipe:1']

        ffmpeg_process = Popen(ffmpeg_command, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        ffmpeg_result, ffmpeg_error = ffmpeg_process.communicate(input=file.getvalue())
        if ffmpeg_process.returncode != 0:
            raise RuntimeError('ffmpeg failed: %s' % ffmpeg_error.decode(errors='replace'))
        return io.BytesIO(ffmpeg_result)

    @staticmethod
    def _download(uri: str) -> io.BytesIO:
        """Download a media file from the given URI.

        :param uri: URI of the media file
        :return: raw file bytes wrapped in io.BytesIO
        """
        response = requests.get(uri)
        response.raise_for_status()
        return io.BytesIO(response.content)

    @staticmethod
    def _duration(path: str) -> float:
        with wave.open(path, 'r') as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)
            return duration


@dataclass
class Frame:
    """Represents a "frame" of audio data."""
    data: bytes
    start: float
    duration: float


class Label(Enum):
    speech = 1
    no_activity = 0


@dataclass
class ExtractionItem:
    label: Label
    start: int
    duration: int

    @property
    def end(self) -> int:
        return self.start + self.duration

    @property
    def view(self) -> Dict[str, int]:
        return {'start': self.start, 'end': self.end, 'label': self.label.value}


class WebRtcExtractor:

    def __init__(self, mode: int = 0) -> None:
        self._mode = mode
        self._vad = webrtcvad.Vad(mode=mode)

    def __call__(self, path: Union[str, io.BytesIO]) -> List[ExtractionItem]:
        audio, sample_rate, _ = read_wave(path)
        frames = frame_generator(30, audio, sample_rate)
        result: List[ExtractionItem] = list()
        for frame in frames:
            is_speech = self._vad.is_speech(frame.data, sample_rate)
            label = Label.speech if is_speech else Label.no_activity
            result.append(
                ExtractionItem(label, round(frame.start * 1000), int(frame.duration * 1000))
            )
        result = refine_extraction(result)
        return result


def frame_generator(frame_duration_ms: int, audio: bytes, sample_rate: int) -> Iterator[Frame]:
    """
    Generates audio frames from PCM audio data.
    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration.
    """
    # 16-bit PCM is 2 bytes per sample, so a 30 ms frame at 16 kHz spans
    # 16000 * 0.030 * 2 = 960 bytes.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0  # frame duration in seconds
    while offset + n <= len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n


def group_extractions(
        extraction_result: List[ExtractionItem]
) -> Iterator[Tuple[ExtractionItem, ExtractionItem]]:
    """Yields (first, last) pairs bounding each run of consecutive same-label items."""
    first = last = extraction_result[0]
    for item in extraction_result[1:]:
        if item.label == last.label:  # part of the current group, bump the end
            last = item
        else:  # not part of the group, yield the current group and start a new one
            yield first, last
            first = last = item
    yield first, last


def refine_extraction(extraction_result: List[ExtractionItem]) -> List[ExtractionItem]:
    """Merges each run of consecutive same-label frames into a single segment."""
    refine_result: List[ExtractionItem] = list()
    for first, last in group_extractions(extraction_result):
        duration = last.start + last.duration - first.start
        refine_result.append(ExtractionItem(first.label, first.start, duration))
    return refine_result
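
# Example of the grouping and refinement above: five 30 ms frames labelled
#   [speech, speech, no_activity, no_activity, speech]
# collapse into three segments of 60 ms speech, 60 ms no_activity,
# and 30 ms speech.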


def read_wave(path: Union[str, io.BytesIO]) -> Tuple[bytes, int, int]:
    """
    Reads a *.wav file or file-like object.
    Takes the path, and returns (PCM audio data, sample rate, sample width).
    """
    with wave.open(path, 'rb') as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1
        sample_width = wf.getsampwidth()
        assert sample_width == 2
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000)
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate, sample_width
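

# A minimal usage sketch for running the extractor outside Label Studio.
# 'example.wav' is a hypothetical path; the file is assumed to already be
# 16 kHz mono 16-bit PCM WAV (the format _ffmpeg_settings produces) --
# transcode other inputs first.
if __name__ == '__main__':
    extractor = WebRtcExtractor(mode=2)
    segments = extractor('example.wav')
    for segment in segments:
        print(segment.view)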
