import contextlib
import io
import wave
from enum import Enum
from typing import Any, Dict, List, Iterator, Tuple, Union
from subprocess import Popen, PIPE

from import LabelStudioMLBase
from import get_single_tag_keys, get_choice, is_skipped
import requests
import webrtcvad

# ffmpeg output options used when transcoding downloaded media: drop any video
# stream (-vn), downmix to one channel (-ac 1), resample to 16 kHz (-ar 16000)
# and write a WAV container (-f wav) holding signed 16-bit little-endian PCM
# (-acodec pcm_s16le). This is the format read_wave() asserts on: mono,
# 2-byte samples, and a sample rate from its supported set.
_ffmpeg_settings = ['-vn', '-ac', '1', '-ar', '16000', '-f', 'wav', '-acodec', 'pcm_s16le']

class LabelStudioPipeline(LabelStudioMLBase):
    """Label Studio ML backend that pre-labels long pauses in audio tasks.

    Each task's media file is downloaded, transcoded with ffmpeg to mono
    16 kHz 16-bit WAV and passed through :class:`WebRtcExtractor`. Every
    ``no_activity`` segment lasting at least one second is returned as an
    "Other" region.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Exactly one control tag is supported and it has to be <Labels>.
        assert len(self.parsed_label_config) == 1
        self.from_name, self.info = list(self.parsed_label_config.items())[0]
        assert self.info['type'] == 'Labels'

        # the model has only one audio input
        assert len(self.info['to_name']) == 1
        assert len(self.info['inputs']) == 1
        self.to_name = self.info['to_name'][0]
        self.value = self.info['inputs'][0]['value']

        self._model = WebRtcExtractor(mode=2)

    def predict(self, tasks: List, **kwargs) -> List:
        """Run voice-activity detection on every task.

        :param tasks: Label Studio tasks; the media URI is read from
            ``task["data"][self.value]``
        :return: one ``{"result": [...]}`` dict per task, in task order
        """
        results = list()
        for task in tasks:
            uri = task["data"][self.value]

            with self._download(uri) as file_bytes, \
                    self._transcode(file_bytes) as transcode_file:
                voice_activity_detection: List[ExtractionItem] = self._model(transcode_file)

            segments_result = list()
            for segment in voice_activity_detection:
                # Only pauses of one second or longer are reported; segment
                # times are in milliseconds, regions are emitted in seconds.
                if segment.label == Label.no_activity and segment.duration >= 1000:
                    segments_result.append({
                        "from_name": self.from_name,
                        "to_name": self.to_name,
                        "type": "labels",
                        "value": {"start": round(segment.start / 1000, 2),
                                  "end": round(segment.end / 1000, 2),
                                  "labels": ["Other"]},
                    })

            results.append({"result": segments_result})
        return results

    def fit(self, completions: List, **kwargs) -> Dict[str, Any]:
        """Train model given list of completions, then returns dict with created links and resources.

        Nothing is trained in this backend, so an empty dict is returned.
        """
        return {}

    @staticmethod
    def _transcode(file: io.BytesIO) -> io.BytesIO:
        """Transcode file represented as raw bytes.

        The bytes are piped through ffmpeg using ``_ffmpeg_settings``.

        :param file: raw file bytes
        :return: transcoded WAV bytes
        :raises RuntimeError: if ffmpeg exits with a non-zero status
        """
        ffmpeg_command = ['ffmpeg', '-y', '-hide_banner', '-i', 'pipe:0', *_ffmpeg_settings, 'pipe:1']

        ffmpeg_process = Popen(ffmpeg_command, stdout=PIPE, stderr=PIPE, stdin=PIPE)
        ffmpeg_result, ffmpeg_error = ffmpeg_process.communicate(input=file.getvalue())
        if ffmpeg_process.returncode != 0:
            # Surface ffmpeg's own diagnostics instead of failing later on an
            # empty or truncated WAV stream.
            raise RuntimeError(
                f"ffmpeg exited with code {ffmpeg_process.returncode}: "
                f"{ffmpeg_error.decode(errors='replace')}"
            )
        return io.BytesIO(ffmpeg_result)

    @staticmethod
    def _download(uri: str) -> io.BytesIO:
        """Downloads media file from given URI and returns results as io.BytesIO.

        :param uri: URI to media file as string
        :return: raw file bytes
        :raises requests.HTTPError: if the server answers with an error status
        """
        response = requests.get(uri)
        # Do not hand an HTTP error page to ffmpeg as if it were media.
        response.raise_for_status()
        return io.BytesIO(response.content)

    @staticmethod
    def _duration(path: str) -> float:
        """Return the duration of a WAV file in seconds.

        :param path: WAV file to open with :func:`wave.open`
        :return: number of frames divided by the frame rate
        """
        with contextlib.closing(wave.open(path, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            return frames / float(rate)

class Frame:
    """Represents a "frame" of audio data.

    :param data: raw PCM bytes of the frame
    :param start: offset of the frame from the start of the audio, in seconds
    :param duration: length of the frame, in seconds
    """

    def __init__(self, data: bytes, start: float, duration: float) -> None:
        self.data = data
        self.start = start
        self.duration = duration

class Label(Enum):
    """Voice-activity classes assigned to frames and merged segments."""
    # Assigned to a frame when the VAD reports speech.
    speech = 1
    # Assigned to every other frame; LabelStudioPipeline.predict reports long
    # runs of this label as regions.
    no_activity = 0

class ExtractionItem:
    """A labelled span described by its start offset and duration.

    ``end`` is computed once at construction time as ``start + duration``.
    """

    def __init__(self, label: Label, start: int = 1, duration: int = -1):
        self.label: Label = label
        self.start: int = start
        self.duration: int = duration
        self.end: int = start + duration

    def view(self) -> Dict[str, Union[int, Label]]:
        """Return the span as a plain dict with the label's raw enum value."""
        stop = self.start + self.duration
        return dict(start=self.start, end=stop, label=self.label.value)

class WebRtcExtractor:
    """Callable that segments a WAV file into speech / no-activity spans.

    :param mode: mode value forwarded to ``webrtcvad.Vad``
    """

    def __init__(self, mode: int = 0) -> None:
        self._mode = mode
        self._vad = webrtcvad.Vad(mode=mode)

    def __call__(self, path: Union[str, io.BytesIO]) -> List[ExtractionItem]:
        """Classify the audio in 30 ms frames and merge equal neighbours.

        :param path: WAV file name or file-like object accepted by ``read_wave``
        :return: merged segments with ``start`` and ``duration`` in
            milliseconds; empty when the audio is shorter than one frame
        """
        audio, sample_rate, _ = read_wave(path)
        frames = frame_generator(30, audio, sample_rate)
        result: List[ExtractionItem] = list()
        for frame in frames:
            is_speech = self._vad.is_speech(frame.data, sample_rate)
            label = Label.speech if is_speech else Label.no_activity
            # Frame times are in seconds; items carry milliseconds.
            result.append(
                ExtractionItem(label, round(frame.start * 1000), int(frame.duration * 1000))
            )
        if not result:
            # No complete frame was available, so there is nothing to merge.
            return result
        result = refine_extraction(result)
        return result

def frame_generator(frame_duration_ms, audio, sample_rate) -> Iterator[Frame]:
    """Generates audio frames from PCM audio data.

    Takes the desired frame duration in milliseconds, the PCM data, and
    the sample rate.
    Yields Frames of the requested duration. A trailing remainder shorter
    than one full frame is not yielded.

    :raises ValueError: if the arguments give a frame size of zero bytes
    """
    # Frame size in bytes; the factor 2 is the 16-bit sample width.
    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
    if n <= 0:
        # A zero-length frame would never advance the offset below.
        raise ValueError("frame_duration_ms and sample_rate must give a non-empty frame")
    offset = 0
    timestamp = 0.0
    duration = (float(n) / sample_rate) / 2.0
    while offset + n <= len(audio):
        yield Frame(audio[offset:offset + n], timestamp, duration)
        timestamp += duration
        offset += n

def group_extractions(extraction_result: List[ExtractionItem]) -> Iterator[Tuple[ExtractionItem, ExtractionItem]]:
    """Yield ``(first, last)`` for every run of consecutive items with the same label.

    Items are compared by ``label.value``. A run of one item yields that item
    as both ``first`` and ``last``. An empty input yields nothing.
    """
    items = iter(extraction_result)
    try:
        first = last = next(items)
    except StopIteration:
        return
    for item in items:
        if item.label.value == last.label.value:  # Part of the group, bump the end
            last = item
        else:  # Not part of the group, yield current group and start a new
            yield first, last
            first = last = item
    yield first, last

def refine_extraction(extraction_result: List[ExtractionItem]) -> List[ExtractionItem]:
    """Merge consecutive items sharing a label into one item per run.

    The merged item starts at the first item's start and spans up to the end
    of the last item in the run.

    :param extraction_result: per-frame items in chronological order
    :return: one item per run; empty when the input is empty
    """
    if not extraction_result:
        return []
    refine_result: List[ExtractionItem] = list()
    for first, last in group_extractions(extraction_result):
        label = first.label
        start = first.start
        duration = last.start + last.duration - first.start
        refine_result.append(ExtractionItem(label, start, duration))
    return refine_result

def read_wave(path: Union[str, io.BytesIO]) -> Tuple[bytes, int, int]:
    """Reads a *.wav file.

    Takes the path (or a binary file-like object) and returns
    (PCM audio data, sample rate, sample width in bytes).

    :raises AssertionError: if the audio is not mono, not 16-bit, or uses a
        sample rate other than 8000, 16000, 32000 or 48000 Hz
    """
    with contextlib.closing(wave.open(path, 'rb')) as wf:
        num_channels = wf.getnchannels()
        assert num_channels == 1, "expected mono audio"
        sample_width = wf.getsampwidth()
        assert sample_width == 2, "expected 16-bit samples"
        sample_rate = wf.getframerate()
        assert sample_rate in (8000, 16000, 32000, 48000), "unsupported sample rate"
        pcm_data = wf.readframes(wf.getnframes())
        return pcm_data, sample_rate, sample_width

