Skip to content

model API

maai.model

Maai

Main wrapper class for running the MaAI model.

Handles audio input streams, model loading, audio processing, feature extraction, and VAP outputs in a background thread.

Source code in src/maai/model.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
class Maai():
    """Main wrapper class for running the MaAI model.

    Handles audio input streams, model loading, audio processing,
    feature extraction, and VAP outputs in a background thread.
    """

    # Index pairs named after the "now" and "future" probability bins.
    # NOTE(review): neither list is referenced by the methods of this class
    # shown here — presumably consumed elsewhere; confirm before changing.
    BINS_P_NOW = [0, 1]
    BINS_PFUTURE = [2, 3]

    # process() prints and resets its timing statistics once more than this
    # many per-call processing times have been collected.
    CALC_PROCESS_TIME_INTERVAL = 100

    def __init__(
        self,
        mode,
        lang: str,
        audio_ch1: Base,
        audio_ch2: Base,
        frame_rate: float = 10,
        context_len_sec: int = 20,
        device: str = "cpu",
        # num_channels: int = 2,
        cpc_model: str = os.path.expanduser("~/.cache/cpc/60k_epoch4-d0f474de.pt"),
        model_type: str = "normal",
        mimi_model_name: str = "kyutai/mimi",
        use_mimi_onnx: bool = True,
        mimi_onnx_precision: str = "fp32",
        mimi_onnx_fp32_path: str | None = None,
        mimi_onnx_fp32_meta_path: str | None = None,
        mimi_onnx_int8_path: str | None = None,
        mimi_onnx_int8_meta_path: str | None = None,
        mimi_local_onnx_fp32_path: str | None = None,
        mimi_local_onnx_fp32_meta_path: str | None = None,
        mimi_local_onnx_int8_path: str | None = None,
        mimi_local_onnx_int8_meta_path: str | None = None,
        mimi_onnx_cpu_intra_threads: int | None = None,
        mimi_onnx_cpu_inter_threads: int | None = None,
        cache_dir: str | None = None,
        force_download: bool = False,
        use_kv_cache: bool = True,
        local_model = None,
        return_p_bins: bool = False,
    ):
        """Initialize the Maai instance.

        Builds the model configuration, instantiates the model class that
        matches ``mode``, loads its weights (downloaded via ``load_vap_model``
        or read from ``local_model``), moves the model to ``device`` in eval
        mode, subscribes to both audio sources and resets the runtime state.
        The background thread is not started here; call ``start()`` for that.

        Args:
            mode (str): Operational mode. Supported values are 'vap',
                'vap_mc', 'bc', 'bc_2type', 'nod', 'nod_para' and
                'vap_prompt'.
            lang (str): Language setting (e.g., 'jp', 'en').
            audio_ch1 (Base): Audio input source for channel 1.
            audio_ch2 (Base): Audio input source for channel 2.
            frame_rate (float): Frame rate for processing audio.
            context_len_sec (int): Audio context length in seconds.
            device (str): Device to run the model on ('cpu', 'cuda', 'cuda:N').
            cpc_model (str): Path to CPC model weights.
            model_type (str): General model type (e.g., 'normal').
            mimi_model_name (str): Hugging Face model name for Mimi.
            use_mimi_onnx (bool): Whether to use ONNX backend for Mimi.
            mimi_onnx_precision (str): Precision for ONNX model ('fp32', 'int8').
            mimi_onnx_fp32_path (str | None): Path to FP32 ONNX model.
            mimi_onnx_fp32_meta_path (str | None): Path to FP32 ONNX meta.
            mimi_onnx_int8_path (str | None): Path to INT8 ONNX model.
            mimi_onnx_int8_meta_path (str | None): Path to INT8 ONNX meta.
            mimi_local_onnx_fp32_path (str | None): Path to local FP32 ONNX
                model. Takes precedence over ``mimi_onnx_fp32_path``.
            mimi_local_onnx_fp32_meta_path (str | None): Path to local FP32
                ONNX meta. Takes precedence over ``mimi_onnx_fp32_meta_path``.
            mimi_local_onnx_int8_path (str | None): Path to local INT8 ONNX
                model. Takes precedence over ``mimi_onnx_int8_path``.
            mimi_local_onnx_int8_meta_path (str | None): Path to local INT8
                ONNX meta. Takes precedence over ``mimi_onnx_int8_meta_path``.
            mimi_onnx_cpu_intra_threads (int | None): ONNX CPU intra-op threads.
            mimi_onnx_cpu_inter_threads (int | None): ONNX CPU inter-op threads.
            cache_dir (str | None): Cache directory for model weights.
            force_download (bool): Force download of model weights.
            use_kv_cache (bool): Whether to use KV caching during inference.
            local_model (str | None): Path to a local custom model file.
            return_p_bins (bool): Whether to return probability bins in the
                'vap' and 'vap_mc' modes.

        Raises:
            ValueError: If ``mimi_onnx_precision`` is not 'fp32' or 'int8',
                if 'int8' is requested together with a CUDA device, if
                ``mode`` is not a supported mode, or if ``device`` is not
                'cpu', 'cuda' or 'cuda:N'.
        """

        self.return_p_bins = bool(return_p_bins)

        encoder_type = resolve_encoder_type(model_type)

        conf = VapConfig()
        conf.frame_hz = float(frame_rate)
        conf.encoder_type = encoder_type
        conf.mimi_model_name = mimi_model_name
        conf.mimi_use_onnx = 1 if bool(use_mimi_onnx) else 0
        precision = str(mimi_onnx_precision).strip().lower()
        if precision not in {"fp32", "int8"}:
            raise ValueError("mimi_onnx_precision must be 'fp32' or 'int8'.")
        if str(device).startswith("cuda") and precision == "int8":
            raise ValueError("mimi_onnx_precision='int8' is not supported with CUDA. Use 'fp32' on CUDA.")
        conf.mimi_onnx_precision = precision
        # Thread counts and ONNX paths are only written when given, so the
        # VapConfig defaults stay in effect otherwise.
        if mimi_onnx_cpu_intra_threads is not None:
            conf.mimi_onnx_cpu_intra_threads = int(mimi_onnx_cpu_intra_threads)
        if mimi_onnx_cpu_inter_threads is not None:
            conf.mimi_onnx_cpu_inter_threads = int(mimi_onnx_cpu_inter_threads)
        # A "local" path wins over the corresponding plain path argument.
        fp32_onnx = mimi_local_onnx_fp32_path or mimi_onnx_fp32_path
        fp32_meta = mimi_local_onnx_fp32_meta_path or mimi_onnx_fp32_meta_path
        int8_onnx = mimi_local_onnx_int8_path or mimi_onnx_int8_path
        int8_meta = mimi_local_onnx_int8_meta_path or mimi_onnx_int8_meta_path
        if fp32_onnx is not None:
            conf.mimi_onnx_fp32_path = str(fp32_onnx)
        if fp32_meta is not None:
            conf.mimi_onnx_fp32_meta_path = str(fp32_meta)
        if int8_onnx is not None:
            conf.mimi_onnx_int8_path = str(int8_onnx)
        if int8_meta is not None:
            conf.mimi_onnx_int8_meta_path = str(int8_meta)
        if cache_dir is not None:
            conf.mimi_onnx_hf_cache_dir = cache_dir
        conf.mimi_onnx_hf_force_download = bool(force_download)

        # # Middle size model
        # if "middle" in lang:
        #     conf.dim = 256
        #     conf.channel_layers = 2
        #     conf.cross_layers = 6
        #     conf.num_heads = 8

        if mode in ["vap", "vap_mc"]:
            self.vap = VapGPT(conf)

        elif mode == "bc":
            self.vap = VapGPT_bc(conf)

        elif mode == "bc_2type":
            self.vap = VapGPT_bc_2type(conf)

        elif mode == "nod":
            self.vap = VapGPT_nod(conf)

        elif mode == "nod_para":
            conf.dropout = 0.2
            self.vap = VapGPT_nod_para(conf)

        elif mode == "vap_prompt":
            from .models.vap_prompt import VapGPT_prompt
            self.vap = VapGPT_prompt(conf)

        else:
            # Without this branch self.vap would never be assigned and the
            # failure would surface later as an unrelated AttributeError.
            raise ValueError(
                f"Unsupported mode: {mode!r}. Expected one of 'vap', 'vap_mc', "
                "'bc', 'bc_2type', 'nod', 'nod_para', 'vap_prompt'."
            )

        try:
            self.device = str(torch.device(device))
        except RuntimeError as exc:
            raise ValueError("Device must be a valid torch device string such as 'cpu', 'cuda', or 'cuda:0'.") from exc

        if not (self.device == "cpu" or self.device.startswith("cuda")):
            raise ValueError("Device must be 'cpu', 'cuda', or 'cuda:N'.")

        # Store the initial state of the model to check for unchanged parameters
        initial_state_dict = {name: param.clone() for name, param in self.vap.named_parameters()}

        nod_param_stats_from_file = None
        nod_count_thresholds_from_file = None
        if local_model is None:
            sd = load_vap_model(
                mode,
                frame_rate,
                context_len_sec,
                lang,
                device,
                cache_dir,
                force_download,
                model_type=model_type,
            )
            # A 'nod_para' checkpoint may be a wrapper dict carrying extra
            # statistics next to the actual state dict.
            if (
                mode == "nod_para"
                and isinstance(sd, dict)
                and "state_dict" in sd
            ):
                nod_param_stats_from_file = sd.get("nod_param_stats")
                nod_count_thresholds_from_file = sd.get("nod_count_thresholds") or sd.get(
                    "nod_repetitions_thresholds"
                )
                sd = sd["state_dict"]
        else:
            print("Loading model from local file:", local_model)
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a trusted source.
            raw = torch.load(local_model, map_location="cpu")
            if isinstance(raw, dict):
                nod_param_stats_from_file = raw.get("nod_param_stats")
                nod_count_thresholds_from_file = raw.get(
                    "nod_count_thresholds"
                ) or raw.get("nod_repetitions_thresholds")
                if "state_dict" in raw:
                    sd = raw["state_dict"]
                else:
                    sd = raw
            else:
                sd = raw

        if hasattr(self.vap, "conf"):
            setattr(self.vap.conf, "runtime_device", self.device)
        self.vap.load_encoder(cpc_model=cpc_model)
        if mode == "nod_para" and isinstance(sd, dict):
            # Rename 'nod_count_head.*' keys to the 'nod_repetitions_head.*'
            # names before loading.
            remapped_sd: dict = {}
            for _k, _v in sd.items():
                if _k.startswith("nod_count_head."):
                    remapped_sd[
                        "nod_repetitions_head." + _k[len("nod_count_head.") :]
                    ] = _v
                else:
                    remapped_sd[_k] = _v
            sd = remapped_sd
        # strict=False: missing or unexpected keys do not raise here.
        self.vap.load_state_dict(sd, strict=False)

        if (
            mode == "nod_para"
            and nod_param_stats_from_file is not None
            and isinstance(nod_param_stats_from_file, dict)
        ):
            for _k in ("range_mean", "range_std", "speed_mean", "speed_std"):
                if _k in nod_param_stats_from_file:
                    self.vap.nod_param_stats[_k] = float(nod_param_stats_from_file[_k])
        if (
            mode == "nod_para"
            and nod_count_thresholds_from_file is not None
            and isinstance(nod_count_thresholds_from_file, dict)
        ):
            for _k in ("t0", "t1", "t2"):
                if _k in nod_count_thresholds_from_file:
                    self.vap.nod_repetitions_thresholds[_k] = float(
                        nod_count_thresholds_from_file[_k]
                    )
            if "t_swing" in nod_count_thresholds_from_file:
                self.vap.nod_swing_up_threshold = float(nod_count_thresholds_from_file["t_swing"])

        # The checkpoint stores one set of 'encoder.*' weights; copy it into
        # both per-channel encoders.
        if conf.encoder_type == "cpc" and 'encoder.downsample.1.weight' in sd:
            self.vap.encoder1.downsample[1].weight = nn.Parameter(sd['encoder.downsample.1.weight'])
            self.vap.encoder1.downsample[1].bias = nn.Parameter(sd['encoder.downsample.1.bias'])
            self.vap.encoder1.downsample[2].ln.weight = nn.Parameter(sd['encoder.downsample.2.ln.weight'])
            self.vap.encoder1.downsample[2].ln.bias = nn.Parameter(sd['encoder.downsample.2.ln.bias'])

            self.vap.encoder2.downsample[1].weight = nn.Parameter(sd['encoder.downsample.1.weight'])
            self.vap.encoder2.downsample[1].bias = nn.Parameter(sd['encoder.downsample.1.bias'])
            self.vap.encoder2.downsample[2].ln.weight = nn.Parameter(sd['encoder.downsample.2.ln.weight'])
            self.vap.encoder2.downsample[2].ln.bias = nn.Parameter(sd['encoder.downsample.2.ln.bias'])

        # print(sd.keys())
        # input("Model loaded. Press Enter to continue...")
        if conf.encoder_type == "mimi" and 'encoder.frame_rate_conv.weight' in sd:
            self.vap.encoder1.frame_rate_conv.weight = nn.Parameter(sd['encoder.frame_rate_conv.weight'])
            self.vap.encoder1.frame_rate_conv.bias = nn.Parameter(sd['encoder.frame_rate_conv.bias'])

            self.vap.encoder2.frame_rate_conv.weight = nn.Parameter(sd['encoder.frame_rate_conv.weight'])
            self.vap.encoder2.frame_rate_conv.bias = nn.Parameter(sd['encoder.frame_rate_conv.bias'])

        # Check for parameters that were not updated from their initial values
        # NOTE(review): the exclusion below matches names starting with
        # 'encoder.' only; confirm it also covers the intended 'encoder1' /
        # 'encoder2' parameters.
        for name, param in self.vap.named_parameters():
            if name in initial_state_dict:
                if torch.equal(param.data, initial_state_dict[name].data):
                    # Exclude encoder parameters that are loaded separately
                    if not name.startswith('encoder.'):
                        print(f"Warning: Parameter '{name}' was not updated from its initial value.")

        self.vap.to(self.device)
        self.vap = self.vap.eval()

        self.mode = mode
        self.model_type = model_type
        self.encoder_type = encoder_type
        self._use_mimi_onnx = bool(use_mimi_onnx)
        self.mic1 = audio_ch1
        self.mic2 = audio_ch2

        # Always subscribe a dedicated queue for each mic if possible
        self._mic1_queue = self.mic1.subscribe()
        self._mic2_queue = self.mic2.subscribe()

        # The misspelled attribute names below are kept as they are because
        # they are readable from outside the class.
        self.audio_contenxt_lim_sec = context_len_sec
        self.frame_rate = float(frame_rate)

        # Context length of the audio embeddings (depends on frame rate)
        self.audio_context_len = int(round(self.audio_contenxt_lim_sec * self.frame_rate))

        self.sampling_rate = 16000
        self.frame_contxt_padding = 320

        # Frame size
        # 10Hz -> 320 + 1600 samples
        # 12.5Hz -> 320 + 1280 samples
        # 20Hz -> 320 + 800 samples
        # 50Hz -> 320 + 320 samples
        self.audio_frame_size = int(round(self.sampling_rate / self.frame_rate)) + self.frame_contxt_padding

        self.current_x1_audio = []
        self.current_x2_audio = []

        self.result_p_now = 0.
        self.result_p_future = 0.
        self.result_p_bc_react = 0.
        self.result_p_bc_emo = 0.
        self.result_p_bc = 0.
        self.result_p_nod_short = 0.
        self.result_p_nod_long = 0.
        self.result_p_nod_long_p = 0.
        self.result_last_time = -1

        self.result_vad = [0., 0.]

        # Wall-clock time of the last completed process() call (-1 = none yet)
        self.process_time_abs = -1

        self.e1_full = []
        self.e2_full = []

        self.list_process_time_context = []
        self.last_interval_time = time.time()

        # Results produced by process() and consumed by get_result()
        self.result_dict_queue = queue.Queue()

        self.use_kv_cache = use_kv_cache
        self.vap_cache = None

        # Thread control
        self._stop_event = threading.Event()
        self._worker_thread = None

        self.reset_runtime_state()

    def reset_runtime_state(self):
        """Reset the internal audio buffers and cache states."""
        self.current_x1_audio = []
        self.current_x2_audio = []
        self.e1_full = []
        self.e2_full = []
        self.vap_cache = None
        self._skip_first_encoder_output = bool(self.encoder_type == "mimi" and self._use_mimi_onnx)

        for encoder_name in ["encoder1", "encoder2"]:
            encoder = getattr(self.vap, encoder_name, None)
            if encoder is not None and hasattr(encoder, "reset_streaming_state"):
                encoder.reset_streaming_state()

    # def _increase_mimi_chunk_threshold(self, attempted_num_samples: int):
    #     if self.encoder_type != "mimi":
    #         return

    #     previous_threshold = int(self.audio_frame_size)
    #     next_threshold = int(attempted_num_samples) + int(Base.FRAME_SIZE)
    #     if next_threshold <= self.audio_frame_size:
    #         next_threshold = self.audio_frame_size + int(Base.FRAME_SIZE)

    #     self.audio_frame_size = next_threshold
    #     if self.audio_frame_size != previous_threshold:
    #         print(
    #             f"[Info] Mimi streaming chunk threshold adjusted: {previous_threshold} -> {self.audio_frame_size} samples "
    #             f"({self.audio_frame_size / self.sampling_rate:.3f} sec)."
    #         )

    def worker(self):
        """Background loop to fetch audio from queues and run inference."""

        # Clear the queues at the start
        # This is to ensure that the queues are empty before starting the processing loop
        self._mic1_queue.queue.clear()
        self._mic2_queue.queue.clear()

        while not self._stop_event.is_set():
            x1 = self.mic1.get_audio_data(self._mic1_queue)
            x2 = self.mic2.get_audio_data(self._mic2_queue)

            if self._stop_event.is_set() or x1 is None or x2 is None:
                break

            self.process(x1, x2)

            # Clear the queues if they are too large
            if self._mic1_queue.qsize() > 100:
                self._mic1_queue.queue.clear()
                print("[Warning] Audio queue (channel 1) overflow detected. Clearing audio queues.")
            if self._mic2_queue.qsize() > 100:
                self._mic2_queue.queue.clear()
                print("[Warning] Audio queue (channel 2) overflow detected. Clearing audio queues.")

            # print(self._mic1_queue.qsize(), self._mic2_queue.qsize())

            # self._mic1_queue.queue.clear()
            # self._mic2_queue.queue.clear()

    def start(self):
        """Start the background audio fetching and processing thread."""

        self.reset_runtime_state()

        self.mic1.start()
        self.mic2.start()
        self._stop_event.clear()
        # Queue を空にしてからスレッド起動(スレッドに古いデータが届かないよう先にクリア)
        self._mic1_queue.queue.clear()
        self._mic2_queue.queue.clear()
        self._worker_thread = threading.Thread(target=self.worker, daemon=True)
        self._worker_thread.start()

    def stop(self, wait: bool = True, timeout: float = 2.0):
        """
        Safely stop the background processing thread.
        Args:
            wait (bool): If True, wait for the thread to finish.
            timeout (float): Max seconds to wait when joining.
        """
        self._stop_event.set()
        # Unblock blocking gets by pushing sentinels
        try:
            self._mic1_queue.put(None)
            self._mic2_queue.put(None)
        except Exception:
            pass
        if wait and self._worker_thread is not None and self._worker_thread.is_alive():
            self._worker_thread.join(timeout=timeout)

        # Best-effort queue cleanup
        try:
            self._mic1_queue.queue.clear()
            self._mic2_queue.queue.clear()
        except Exception:
            pass

        self.reset_runtime_state()

    def process(self, x1, x2):
        """Process a chunk of audio for both channels.

        Args:
            x1 (np.ndarray): Audio data chunk for channel 1.
            x2 (np.ndarray): Audio data chunk for channel 2.
        """

        time_start = time.time()

        # Initialize buffer if empty
        if len(self.current_x1_audio) == 0:
            self.current_x1_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)
        if len(self.current_x2_audio) == 0:
            self.current_x2_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)

        # x1 = x1.astype(np.float32, copy=False)
        # x2 = x2.astype(np.float32, copy=False)

        # Add to buffer
        self.current_x1_audio = np.concatenate([self.current_x1_audio, x1])
        self.current_x2_audio = np.concatenate([self.current_x2_audio, x2])

        # Return if the buffer does not have enough length
        if len(self.current_x1_audio) < self.audio_frame_size:
            return

        # Extract data for inference
        x1_proc = self.current_x1_audio
        x2_proc = self.current_x2_audio

        x1_dist = x1_proc[self.frame_contxt_padding:]
        x2_dist = x2_proc[self.frame_contxt_padding:]

        with torch.inference_mode():
            # Create tensors more efficiently with specified dtype and device
            x1_ = torch.from_numpy(x1_proc).float().unsqueeze(0).unsqueeze(0)
            x2_ = torch.from_numpy(x2_proc).float().unsqueeze(0).unsqueeze(0)

            # Move to device only once
            if self.device != 'cpu':
                x1_ = x1_.to(self.device, non_blocking=True)
                x2_ = x2_.to(self.device, non_blocking=True)

            # try:
            e1, e2 = self.vap.encode_audio(x1_, x2_)
            # except RuntimeError as exc:
            #     short_chunk_error = (
            #         self.encoder_type == "mimi"
            #         and "Calculated padded input size per channel" in str(exc)
            #         and "Kernel size can't be greater than actual input size" in str(exc)
            #     )
            #     if short_chunk_error:
            #         self._increase_mimi_chunk_threshold(len(self.current_x1_audio))
            #         self.process_time_abs = time.time()
            #         return
            #     raise

            if e1.shape[1] == 0 or e2.shape[1] == 0:
                # if self.encoder_type == "mimi":
                #     self._increase_mimi_chunk_threshold(len(self.current_x1_audio))
                # self.process_time_abs = time.time()
                if self.frame_contxt_padding > 0:
                    self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
                    self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
                else:
                    self.current_x1_audio = np.empty(0, dtype=np.float32)
                    self.current_x2_audio = np.empty(0, dtype=np.float32)
                print("[Warning] No audio features extracted. Skipping this frame.")
                return

            # Skip the first Mimi encoder output to avoid the startup-only mismatch
            # between ONNX and PyTorch cache warmup behavior.
            if self._skip_first_encoder_output:
                self._skip_first_encoder_output = False
                self.process_time_abs = time.time()
                if self.frame_contxt_padding > 0:
                    self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
                    self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
                else:
                    self.current_x1_audio = np.empty(0, dtype=np.float32)
                    self.current_x2_audio = np.empty(0, dtype=np.float32)
                return

            # Full model
            if not self.use_kv_cache:

                self.e1_full.append(e1)
                self.e2_full.append(e2)

                # More efficient context management
                if len(self.e1_full) > self.audio_context_len:
                    self.e1_full.pop(0)  # Remove from front instead of slicing
                if len(self.e2_full) > self.audio_context_len:
                    self.e2_full.pop(0)

                x1_full_ = torch.cat(self.e1_full, dim=1)
                x2_full_ = torch.cat(self.e2_full, dim=1)

                # Move to device only if necessary
                if self.device != 'cpu':
                    x1_full_ = x1_full_.to(self.device, non_blocking=True)
                    x2_full_ = x2_full_.to(self.device, non_blocking=True)

                out, _ = self.vap.forward(x1_full_, x2_full_, cache=None)

            # User KV cache
            elif self.use_kv_cache:

                out, self.vap_cache = self.vap.forward(e1, e2, cache=self.vap_cache)

                ## Trim all cache data in self.vap_cache so that the second-to-last dimension is self.audio_context_len - 1
                if self.vap_cache is not None:
                    new_cache = {}
                    for key, (k_list, v_list) in self.vap_cache.items():
                        new_k_list = []
                        new_v_list = []
                        for t in k_list:
                            if isinstance(t, torch.Tensor) and t.dim() >= 3:
                                new_k_list.append(t[..., -(self.audio_context_len - 1) :, :])
                            else:
                                new_k_list.append(t)
                        for t in v_list:
                            if isinstance(t, torch.Tensor) and t.dim() >= 3:
                                new_v_list.append(t[..., -(self.audio_context_len - 1) :, :])
                            else:
                                new_v_list.append(t)
                        new_cache[key] = (new_k_list, new_v_list)
                    self.vap_cache = new_cache

            # Pre-create result dict structure to avoid repeated key creation
            result_dict = {
                "t": time.time(),
                "x1": x1_dist.copy(),  # Only copy when necessary
                "x2": x2_dist.copy(),
            }

            # Use dictionary mapping for mode-specific outputs (faster than if-elif chain)
            mode_outputs = {
                "vap": lambda: {
                    "p_now": out['p_now'],
                    "p_future": out['p_future'],
                    "vad": out['vad'],
                    "p_bins": out['p_bins'],
                    "p_bins_now": out['p_bins_now'],
                    "p_bins_future": out['p_bins_future'],
                },
                "vap_mc": lambda: {
                    "p_now": out['p_now'],
                    "p_future": out['p_future'],
                    "vad": out['vad'],
                    "p_bins": out['p_bins'],
                    "p_bins_now": out['p_bins_now'],
                    "p_bins_future": out['p_bins_future'],
                },
                "vap_prompt": lambda: {
                    "p_now": out['p_now'],
                    "p_future": out['p_future'],
                    "vad": out['vad']
                },
                "bc": lambda: {
                    "p_bc": out['p_bc'],
                    "p_bc_detect": out['p_bc_detect']
                },
                "bc_2type": lambda: {
                    "p_bc_react": out['p_bc_react'],
                    "p_bc_emo": out['p_bc_emo']
                },
                "nod": lambda: {
                    "p_bc": out['p_bc'],
                    "p_nod_short": out['p_nod_short'],
                    "p_nod_long": out['p_nod_long'],
                    "p_nod_long_p": out['p_nod_long_p']
                },
                "nod_para": lambda: {
                    "p_nod": out["p_nod"],
                    "nod_repetitions": out["nod_repetitions"],
                    "nod_repetitions_pred": out["nod_repetitions_pred"],
                    "nod_range": out["nod_range"],
                    "nod_speed": out["nod_speed"],
                    "nod_swing_up": out["nod_swing_up"],
                    "nod_swing_up_pred": out["nod_swing_up_pred"],
                },
            }

            # Get mode-specific outputs
            if self.mode in mode_outputs:
                _out = mode_outputs[self.mode]()
                if not self.return_p_bins and self.mode in ("vap", "vap_mc"):
                    for _k in ("p_bins", "p_bins_now", "p_bins_future"):
                        _out.pop(_k, None)
                result_dict.update(_out)

            self.result_dict_queue.put(result_dict)

            time_process = time.time() - time_start
            self.list_process_time_context.append(time_process)

            # Performance monitoring (unchanged for clarity)
            if len(self.list_process_time_context) > self.CALC_PROCESS_TIME_INTERVAL:
                ave_proc_time = np.mean(self.list_process_time_context)  # np.mean is faster than np.average
                num_process_frame = len(self.list_process_time_context) / (time.time() - self.last_interval_time)
                self.last_interval_time = time.time()

                perf_message = f'[{self.mode}] Average processing time: {ave_proc_time:.5f} [sec], #process/sec: {num_process_frame:.3f}'
                if self.encoder_type == "mimi":
                    perf_message += f', chunk_samples: {self.audio_frame_size}'
                print(perf_message)
                self.list_process_time_context.clear()  # clear() is faster than = []

            self.process_time_abs = time.time()

        # Keep only the last samples in the buffer (use views for efficiency)
        if self.frame_contxt_padding > 0:
            self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
            self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
        else:
            self.current_x1_audio = np.empty(0, dtype=np.float32)
            self.current_x2_audio = np.empty(0, dtype=np.float32)

    def get_result(self):
        """Retrieve the latest inference result from the queue.

        Returns:
            dict: The latest result containing predictions and raw audio data.
        """
        return self.result_dict_queue.get()

    def set_prompt_ch1(self, prompt: str):
        """
        Set the prompt text for speaker 1. This method is only available for the 'vap_prompt' mode.

        Args:
            prompt (str): The prompt text for speaker 1.
        """

        if self.mode != "vap_prompt":
            raise ValueError("This method is only available for the 'vap_prompt' mode.")

        self.vap.set_prompt_ch1(prompt, self.device)

    def set_prompt_ch2(self, prompt: str):
        """
        Set the prompt text for speaker 2. This method is only available for the 'vap_prompt' mode.

        Args:
            prompt (str): The prompt text for speaker 2.
        """

        if self.mode != "vap_prompt":
            raise ValueError("This method is only available for the 'vap_prompt' mode.")

        self.vap.set_prompt_ch2(prompt, self.device)

__init__(mode, lang, audio_ch1, audio_ch2, frame_rate=10, context_len_sec=20, device='cpu', cpc_model=os.path.expanduser('~/.cache/cpc/60k_epoch4-d0f474de.pt'), model_type='normal', mimi_model_name='kyutai/mimi', use_mimi_onnx=True, mimi_onnx_precision='fp32', mimi_onnx_fp32_path=None, mimi_onnx_fp32_meta_path=None, mimi_onnx_int8_path=None, mimi_onnx_int8_meta_path=None, mimi_local_onnx_fp32_path=None, mimi_local_onnx_fp32_meta_path=None, mimi_local_onnx_int8_path=None, mimi_local_onnx_int8_meta_path=None, mimi_onnx_cpu_intra_threads=None, mimi_onnx_cpu_inter_threads=None, cache_dir=None, force_download=False, use_kv_cache=True, local_model=None, return_p_bins=False)

Initialize the Maai instance.

Parameters:

Name Type Description Default
mode str

Operational mode (e.g., 'vap', 'bc', 'nod').

required
lang str

Language setting (e.g., 'jp', 'en').

required
audio_ch1 Base

Audio input source for channel 1.

required
audio_ch2 Base

Audio input source for channel 2.

required
frame_rate float

Frame rate for processing audio.

10
context_len_sec int

Audio context length in seconds.

20
device str

Device to run the model on ('cpu', 'cuda').

'cpu'
cpc_model str

Path to CPC model weights.

expanduser('~/.cache/cpc/60k_epoch4-d0f474de.pt')
model_type str

General model type (e.g., 'normal').

'normal'
mimi_model_name str

Hugging Face model name for Mimi.

'kyutai/mimi'
use_mimi_onnx bool

Whether to use ONNX backend for Mimi.

True
mimi_onnx_precision str

Precision for ONNX model ('fp32', 'int8').

'fp32'
mimi_onnx_fp32_path str | None

Path to FP32 ONNX model.

None
mimi_onnx_fp32_meta_path str | None

Path to FP32 ONNX meta.

None
mimi_onnx_int8_path str | None

Path to INT8 ONNX model.

None
mimi_onnx_int8_meta_path str | None

Path to INT8 ONNX meta.

None
mimi_local_onnx_fp32_path str | None

Path to local FP32 ONNX model.

None
mimi_local_onnx_fp32_meta_path str | None

Path to local FP32 ONNX meta.

None
mimi_local_onnx_int8_path str | None

Path to local INT8 ONNX model.

None
mimi_local_onnx_int8_meta_path str | None

Path to local INT8 ONNX meta.

None
mimi_onnx_cpu_intra_threads int | None

ONNX CPU intra-op threads.

None
mimi_onnx_cpu_inter_threads int | None

ONNX CPU inter-op threads.

None
cache_dir str

Cache directory for model weights.

None
force_download bool

Force download of model weights.

False
use_kv_cache bool

Whether to use KV caching during inference.

True
local_model str | None

Path to a local custom model file.

None
return_p_bins bool

Whether to return probability bins in 'vap' mode.

False
Source code in src/maai/model.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
def __init__(
    self,
    mode: str,
    lang: str,
    audio_ch1: Base,
    audio_ch2: Base,
    frame_rate: float = 10,
    context_len_sec: int = 20,
    device: str = "cpu",
    # num_channels: int = 2,
    cpc_model: str = os.path.expanduser("~/.cache/cpc/60k_epoch4-d0f474de.pt"),
    model_type: str = "normal",
    mimi_model_name: str = "kyutai/mimi",
    use_mimi_onnx: bool = True,
    mimi_onnx_precision: str = "fp32",
    mimi_onnx_fp32_path: str | None = None,
    mimi_onnx_fp32_meta_path: str | None = None,
    mimi_onnx_int8_path: str | None = None,
    mimi_onnx_int8_meta_path: str | None = None,
    mimi_local_onnx_fp32_path: str | None = None,
    mimi_local_onnx_fp32_meta_path: str | None = None,
    mimi_local_onnx_int8_path: str | None = None,
    mimi_local_onnx_int8_meta_path: str | None = None,
    mimi_onnx_cpu_intra_threads: int | None = None,
    mimi_onnx_cpu_inter_threads: int | None = None,
    cache_dir: str | None = None,
    force_download: bool = False,
    use_kv_cache: bool = True,
    local_model: str | None = None,
    return_p_bins: bool = False,
):
    """Initialize the Maai instance.

    Builds the model configuration, instantiates the model class that
    matches ``mode``, loads its weights (downloaded via ``load_vap_model``
    or read from ``local_model``), moves the model to ``device`` in eval
    mode, subscribes to both audio sources and initializes the runtime
    buffers.

    Args:
        mode (str): Operational mode. One of 'vap', 'vap_mc', 'bc',
            'bc_2type', 'nod', 'nod_para' or 'vap_prompt'.
        lang (str): Language setting (e.g., 'jp', 'en').
        audio_ch1 (Base): Audio input source for channel 1.
        audio_ch2 (Base): Audio input source for channel 2.
        frame_rate (float): Frame rate for processing audio.
        context_len_sec (int): Audio context length in seconds.
        device (str): Device to run the model on ('cpu', 'cuda', 'cuda:N').
        cpc_model (str): Path to CPC model weights.
        model_type (str): General model type (e.g., 'normal').
        mimi_model_name (str): Hugging Face model name for Mimi.
        use_mimi_onnx (bool): Whether to use ONNX backend for Mimi.
        mimi_onnx_precision (str): Precision for ONNX model ('fp32', 'int8').
        mimi_onnx_fp32_path (str | None): Path to FP32 ONNX model.
        mimi_onnx_fp32_meta_path (str | None): Path to FP32 ONNX meta.
        mimi_onnx_int8_path (str | None): Path to INT8 ONNX model.
        mimi_onnx_int8_meta_path (str | None): Path to INT8 ONNX meta.
        mimi_local_onnx_fp32_path (str | None): Path to local FP32 ONNX model.
            Takes precedence over ``mimi_onnx_fp32_path`` when both are given.
        mimi_local_onnx_fp32_meta_path (str | None): Path to local FP32 ONNX meta.
            Takes precedence over ``mimi_onnx_fp32_meta_path``.
        mimi_local_onnx_int8_path (str | None): Path to local INT8 ONNX model.
            Takes precedence over ``mimi_onnx_int8_path``.
        mimi_local_onnx_int8_meta_path (str | None): Path to local INT8 ONNX meta.
            Takes precedence over ``mimi_onnx_int8_meta_path``.
        mimi_onnx_cpu_intra_threads (int | None): ONNX CPU intra-op threads.
        mimi_onnx_cpu_inter_threads (int | None): ONNX CPU inter-op threads.
        cache_dir (str | None): Cache directory for model weights.
        force_download (bool): Force download of model weights.
        use_kv_cache (bool): Whether to use KV caching during inference.
        local_model (str | None): Path to a local custom model file.
        return_p_bins (bool): Whether to return probability bins in 'vap' mode.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes, if
            ``mimi_onnx_precision`` is neither 'fp32' nor 'int8', if 'int8'
            is requested together with a CUDA device, or if ``device`` is not
            a valid 'cpu' / 'cuda' / 'cuda:N' device string.
    """

    self.return_p_bins = bool(return_p_bins)

    encoder_type = resolve_encoder_type(model_type)

    conf = VapConfig()
    conf.frame_hz = float(frame_rate)
    conf.encoder_type = encoder_type
    conf.mimi_model_name = mimi_model_name
    conf.mimi_use_onnx = 1 if bool(use_mimi_onnx) else 0
    precision = str(mimi_onnx_precision).strip().lower()
    if precision not in {"fp32", "int8"}:
        raise ValueError("mimi_onnx_precision must be 'fp32' or 'int8'.")
    if str(device).startswith("cuda") and precision == "int8":
        raise ValueError("mimi_onnx_precision='int8' is not supported with CUDA. Use 'fp32' on CUDA.")
    conf.mimi_onnx_precision = precision
    if mimi_onnx_cpu_intra_threads is not None:
        conf.mimi_onnx_cpu_intra_threads = int(mimi_onnx_cpu_intra_threads)
    if mimi_onnx_cpu_inter_threads is not None:
        conf.mimi_onnx_cpu_inter_threads = int(mimi_onnx_cpu_inter_threads)
    # The "local" path arguments win over their non-local counterparts.
    fp32_onnx = mimi_local_onnx_fp32_path or mimi_onnx_fp32_path
    fp32_meta = mimi_local_onnx_fp32_meta_path or mimi_onnx_fp32_meta_path
    int8_onnx = mimi_local_onnx_int8_path or mimi_onnx_int8_path
    int8_meta = mimi_local_onnx_int8_meta_path or mimi_onnx_int8_meta_path
    if fp32_onnx is not None:
        conf.mimi_onnx_fp32_path = str(fp32_onnx)
    if fp32_meta is not None:
        conf.mimi_onnx_fp32_meta_path = str(fp32_meta)
    if int8_onnx is not None:
        conf.mimi_onnx_int8_path = str(int8_onnx)
    if int8_meta is not None:
        conf.mimi_onnx_int8_meta_path = str(int8_meta)
    if cache_dir is not None:
        conf.mimi_onnx_hf_cache_dir = cache_dir
    conf.mimi_onnx_hf_force_download = bool(force_download)

    # # Middle size model
    # if "middle" in lang:
    #     conf.dim = 256
    #     conf.channel_layers = 2
    #     conf.cross_layers = 6
    #     conf.num_heads = 8

    # Pick the model class that matches the requested mode.
    if mode in ["vap", "vap_mc"]:
        self.vap = VapGPT(conf)

    elif mode == "bc":
        self.vap = VapGPT_bc(conf)

    elif mode == "bc_2type":
        self.vap = VapGPT_bc_2type(conf)

    elif mode == "nod":
        self.vap = VapGPT_nod(conf)

    elif mode == "nod_para":
        conf.dropout = 0.2
        self.vap = VapGPT_nod_para(conf)

    elif mode == "vap_prompt":
        # Imported lazily so the prompt model is only loaded when requested.
        from .models.vap_prompt import VapGPT_prompt
        self.vap = VapGPT_prompt(conf)

    else:
        # Without this branch ``self.vap`` would stay unset and the failure
        # would only surface later as an AttributeError.
        raise ValueError(
            f"Unsupported mode: {mode!r}. Expected one of 'vap', 'vap_mc', 'bc', "
            "'bc_2type', 'nod', 'nod_para', 'vap_prompt'."
        )

    try:
        self.device = str(torch.device(device))
    except RuntimeError as exc:
        raise ValueError("Device must be a valid torch device string such as 'cpu', 'cuda', or 'cuda:0'.") from exc

    if not (self.device == "cpu" or self.device.startswith("cuda")):
        raise ValueError("Device must be 'cpu', 'cuda', or 'cuda:N'.")

    # Store the initial state of the model to check for unchanged parameters
    initial_state_dict = {name: param.clone() for name, param in self.vap.named_parameters()}

    nod_param_stats_from_file = None
    nod_count_thresholds_from_file = None
    if local_model is None:
        sd = load_vap_model(
            mode,
            frame_rate,
            context_len_sec,
            lang,
            device,
            cache_dir,
            force_download,
            model_type=model_type,
        )
        # For 'nod_para' the loaded object may be a wrapper dict that carries
        # extra statistics next to the actual state dict.
        if (
            mode == "nod_para"
            and isinstance(sd, dict)
            and "state_dict" in sd
        ):
            nod_param_stats_from_file = sd.get("nod_param_stats")
            nod_count_thresholds_from_file = sd.get("nod_count_thresholds") or sd.get(
                "nod_repetitions_thresholds"
            )
            sd = sd["state_dict"]
    else:
        print("Loading model from local file:", local_model)
        raw = torch.load(local_model, map_location="cpu")
        if isinstance(raw, dict):
            nod_param_stats_from_file = raw.get("nod_param_stats")
            nod_count_thresholds_from_file = raw.get(
                "nod_count_thresholds"
            ) or raw.get("nod_repetitions_thresholds")
            if "state_dict" in raw:
                sd = raw["state_dict"]
            else:
                sd = raw
        else:
            sd = raw

    if hasattr(self.vap, "conf"):
        setattr(self.vap.conf, "runtime_device", self.device)
    self.vap.load_encoder(cpc_model=cpc_model)
    if mode == "nod_para" and isinstance(sd, dict):
        # Rename keys stored under the "nod_count_head." prefix to the
        # "nod_repetitions_head." prefix before loading.
        remapped_sd: dict = {}
        for _k, _v in sd.items():
            if _k.startswith("nod_count_head."):
                remapped_sd[
                    "nod_repetitions_head." + _k[len("nod_count_head.") :]
                ] = _v
            else:
                remapped_sd[_k] = _v
        sd = remapped_sd
    self.vap.load_state_dict(sd, strict=False)

    if (
        mode == "nod_para"
        and nod_param_stats_from_file is not None
        and isinstance(nod_param_stats_from_file, dict)
    ):
        for _k in ("range_mean", "range_std", "speed_mean", "speed_std"):
            if _k in nod_param_stats_from_file:
                self.vap.nod_param_stats[_k] = float(nod_param_stats_from_file[_k])
    if (
        mode == "nod_para"
        and nod_count_thresholds_from_file is not None
        and isinstance(nod_count_thresholds_from_file, dict)
    ):
        for _k in ("t0", "t1", "t2"):
            if _k in nod_count_thresholds_from_file:
                self.vap.nod_repetitions_thresholds[_k] = float(
                    nod_count_thresholds_from_file[_k]
                )
        if "t_swing" in nod_count_thresholds_from_file:
            self.vap.nod_swing_up_threshold = float(nod_count_thresholds_from_file["t_swing"])

    # The checkpoint stores a single set of "encoder.*" weights; copy them
    # into both per-channel encoders.
    if conf.encoder_type == "cpc" and 'encoder.downsample.1.weight' in sd:
        self.vap.encoder1.downsample[1].weight = nn.Parameter(sd['encoder.downsample.1.weight'])
        self.vap.encoder1.downsample[1].bias = nn.Parameter(sd['encoder.downsample.1.bias'])
        self.vap.encoder1.downsample[2].ln.weight = nn.Parameter(sd['encoder.downsample.2.ln.weight'])
        self.vap.encoder1.downsample[2].ln.bias = nn.Parameter(sd['encoder.downsample.2.ln.bias'])

        self.vap.encoder2.downsample[1].weight = nn.Parameter(sd['encoder.downsample.1.weight'])
        self.vap.encoder2.downsample[1].bias = nn.Parameter(sd['encoder.downsample.1.bias'])
        self.vap.encoder2.downsample[2].ln.weight = nn.Parameter(sd['encoder.downsample.2.ln.weight'])
        self.vap.encoder2.downsample[2].ln.bias = nn.Parameter(sd['encoder.downsample.2.ln.bias'])

    if conf.encoder_type == "mimi" and 'encoder.frame_rate_conv.weight' in sd:
        self.vap.encoder1.frame_rate_conv.weight = nn.Parameter(sd['encoder.frame_rate_conv.weight'])
        self.vap.encoder1.frame_rate_conv.bias = nn.Parameter(sd['encoder.frame_rate_conv.bias'])

        self.vap.encoder2.frame_rate_conv.weight = nn.Parameter(sd['encoder.frame_rate_conv.weight'])
        self.vap.encoder2.frame_rate_conv.bias = nn.Parameter(sd['encoder.frame_rate_conv.bias'])

    # Check for parameters that were not updated from their initial values
    for name, param in self.vap.named_parameters():
        if name in initial_state_dict:
            if torch.equal(param.data, initial_state_dict[name].data):
                # Exclude encoder parameters that are loaded separately
                if not name.startswith('encoder.'):
                    print(f"Warning: Parameter '{name}' was not updated from its initial value.")

    self.vap.to(self.device)
    self.vap = self.vap.eval()

    self.mode = mode
    self.model_type = model_type
    self.encoder_type = encoder_type
    self._use_mimi_onnx = bool(use_mimi_onnx)
    self.mic1 = audio_ch1
    self.mic2 = audio_ch2

    # Always subscribe a dedicated queue for each mic if possible
    self._mic1_queue = self.mic1.subscribe()
    self._mic2_queue = self.mic2.subscribe()

    self.audio_contenxt_lim_sec = context_len_sec
    self.frame_rate = float(frame_rate)

    # Context length of the audio embeddings (depends on frame rate)
    self.audio_context_len = int(round(self.audio_contenxt_lim_sec * self.frame_rate))

    self.sampling_rate = 16000
    self.frame_contxt_padding = 320

    # Frame size
    # 10Hz -> 320 + 1600 samples
    # 12.5Hz -> 320 + 1280 samples
    # 20Hz -> 320 + 800 samples
    # 50Hz -> 320 + 320 samples
    self.audio_frame_size = int(round(self.sampling_rate / self.frame_rate)) + self.frame_contxt_padding

    self.current_x1_audio = []
    self.current_x2_audio = []

    self.result_p_now = 0.
    self.result_p_future = 0.
    self.result_p_bc_react = 0.
    self.result_p_bc_emo = 0.
    self.result_p_bc = 0.
    self.result_p_nod_short = 0.
    self.result_p_nod_long = 0.
    self.result_p_nod_long_p = 0.
    self.result_last_time = -1

    self.result_vad = [0., 0.]

    self.process_time_abs = -1

    self.e1_full = []
    self.e2_full = []

    self.list_process_time_context = []
    self.last_interval_time = time.time()

    self.result_dict_queue = queue.Queue()

    self.use_kv_cache = use_kv_cache
    self.vap_cache = None

    # Thread control
    self._stop_event = threading.Event()
    self._worker_thread = None

    self.reset_runtime_state()

get_result()

Retrieve the latest inference result from the queue.

Returns:

Name Type Description
dict

The latest result containing predictions and raw audio data.

Source code in src/maai/model.py
654
655
656
657
658
659
660
def get_result(self):
    """Take the next inference result off the result queue.

    The queue is created as a ``queue.Queue`` in ``__init__`` and is read
    here with ``get()`` using its default arguments, so results come back
    in the order they were queued and the call waits until one is
    available.

    Returns:
        dict: The next queued result containing predictions and raw audio data.
    """
    next_result = self.result_dict_queue.get()
    return next_result

process(x1, x2)

Process a chunk of audio for both channels.

Parameters:

Name Type Description Default
x1 ndarray

Audio data chunk for channel 1.

required
x2 ndarray

Audio data chunk for channel 2.

required
Source code in src/maai/model.py
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
def process(self, x1, x2) -> None:
    """Buffer one audio chunk per channel and run inference once a frame is ready.

    The chunks are appended to the per-channel buffers. While the channel-1
    buffer is shorter than ``self.audio_frame_size`` the method returns
    without running the model. Otherwise both buffers are encoded with
    ``self.vap.encode_audio``, the model is run (with or without the KV
    cache, depending on ``self.use_kv_cache``) and a result dict is put on
    ``self.result_dict_queue``. Afterwards only the last
    ``self.frame_contxt_padding`` samples of each buffer are kept.

    Args:
        x1 (np.ndarray): Audio data chunk for channel 1.
        x2 (np.ndarray): Audio data chunk for channel 2.

    Returns:
        None. Results are delivered through ``self.result_dict_queue``; each
        queued dict has ``"t"``, ``"x1"`` and ``"x2"`` plus the mode-specific
        keys listed in ``mode_outputs`` below.
    """

    time_start = time.time()

    # Initialize buffer if empty: seed it with `frame_contxt_padding` zeros
    if len(self.current_x1_audio) == 0:
        self.current_x1_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)
    if len(self.current_x2_audio) == 0:
        self.current_x2_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)

    # x1 = x1.astype(np.float32, copy=False)
    # x2 = x2.astype(np.float32, copy=False)

    # Add to buffer
    self.current_x1_audio = np.concatenate([self.current_x1_audio, x1])
    self.current_x2_audio = np.concatenate([self.current_x2_audio, x2])

    # Return if the buffer does not have enough length
    # (only the channel-1 buffer length is checked here)
    if len(self.current_x1_audio) < self.audio_frame_size:
        return

    # Extract data for inference
    x1_proc = self.current_x1_audio
    x2_proc = self.current_x2_audio

    # The samples after the leading padding are what gets reported back
    # to the caller as "x1" / "x2" in the result dict.
    x1_dist = x1_proc[self.frame_contxt_padding:]
    x2_dist = x2_proc[self.frame_contxt_padding:]

    with torch.inference_mode():
        # Wrap the buffers as float tensors with two leading singleton dimensions
        x1_ = torch.from_numpy(x1_proc).float().unsqueeze(0).unsqueeze(0)
        x2_ = torch.from_numpy(x2_proc).float().unsqueeze(0).unsqueeze(0)

        # Move the inputs to the configured device when it is not the CPU
        if self.device != 'cpu':
            x1_ = x1_.to(self.device, non_blocking=True)
            x2_ = x2_.to(self.device, non_blocking=True)

        # try:
        e1, e2 = self.vap.encode_audio(x1_, x2_)
        # except RuntimeError as exc:
        #     short_chunk_error = (
        #         self.encoder_type == "mimi"
        #         and "Calculated padded input size per channel" in str(exc)
        #         and "Kernel size can't be greater than actual input size" in str(exc)
        #     )
        #     if short_chunk_error:
        #         self._increase_mimi_chunk_threshold(len(self.current_x1_audio))
        #         self.process_time_abs = time.time()
        #         return
        #     raise

        # The encoder produced no frames along dimension 1: drop this frame,
        # keeping only the trailing padding samples in the buffers.
        if e1.shape[1] == 0 or e2.shape[1] == 0:
            # if self.encoder_type == "mimi":
            #     self._increase_mimi_chunk_threshold(len(self.current_x1_audio))
            # self.process_time_abs = time.time()
            if self.frame_contxt_padding > 0:
                self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
                self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
            else:
                self.current_x1_audio = np.empty(0, dtype=np.float32)
                self.current_x2_audio = np.empty(0, dtype=np.float32)
            print("[Warning] No audio features extracted. Skipping this frame.")
            return

        # Skip the first Mimi encoder output to avoid the startup-only mismatch
        # between ONNX and PyTorch cache warmup behavior.
        # The flag is set in reset_runtime_state() and cleared here, so this
        # branch runs at most once per reset.
        if self._skip_first_encoder_output:
            self._skip_first_encoder_output = False
            self.process_time_abs = time.time()
            if self.frame_contxt_padding > 0:
                self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
                self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
            else:
                self.current_x1_audio = np.empty(0, dtype=np.float32)
                self.current_x2_audio = np.empty(0, dtype=np.float32)
            return

        # Full model: re-run the model over the whole stored context
        if not self.use_kv_cache:

            self.e1_full.append(e1)
            self.e2_full.append(e2)

            # Keep at most `audio_context_len` encoded frames per channel,
            # discarding the oldest one first
            if len(self.e1_full) > self.audio_context_len:
                self.e1_full.pop(0)
            if len(self.e2_full) > self.audio_context_len:
                self.e2_full.pop(0)

            # Join the stored frames along dimension 1
            x1_full_ = torch.cat(self.e1_full, dim=1)
            x2_full_ = torch.cat(self.e2_full, dim=1)

            # Move to device only if necessary
            if self.device != 'cpu':
                x1_full_ = x1_full_.to(self.device, non_blocking=True)
                x2_full_ = x2_full_.to(self.device, non_blocking=True)

            out, _ = self.vap.forward(x1_full_, x2_full_, cache=None)

        # Use KV cache: only the newly encoded frame is fed to the model
        elif self.use_kv_cache:

            out, self.vap_cache = self.vap.forward(e1, e2, cache=self.vap_cache)

            ## Trim all cache data in self.vap_cache so that the second-to-last dimension is self.audio_context_len - 1
            # NOTE(review): if audio_context_len were 1 the slice start would be
            # -0, which keeps the whole tensor instead of trimming it; confirm
            # that configuration cannot occur.
            if self.vap_cache is not None:
                new_cache = {}
                for key, (k_list, v_list) in self.vap_cache.items():
                    new_k_list = []
                    new_v_list = []
                    for t in k_list:
                        # Entries that are not tensors with 3+ dims are kept as-is
                        if isinstance(t, torch.Tensor) and t.dim() >= 3:
                            new_k_list.append(t[..., -(self.audio_context_len - 1) :, :])
                        else:
                            new_k_list.append(t)
                    for t in v_list:
                        if isinstance(t, torch.Tensor) and t.dim() >= 3:
                            new_v_list.append(t[..., -(self.audio_context_len - 1) :, :])
                        else:
                            new_v_list.append(t)
                    new_cache[key] = (new_k_list, new_v_list)
                self.vap_cache = new_cache

        # Fields shared by every mode; the audio is copied so the queued
        # result does not alias the working buffer
        result_dict = {
            "t": time.time(),
            "x1": x1_dist.copy(),
            "x2": x2_dist.copy(),
        }

        # Mapping from mode to a callable that picks that mode's keys out of
        # `out`; only the callable for the active mode is invoked below
        mode_outputs = {
            "vap": lambda: {
                "p_now": out['p_now'],
                "p_future": out['p_future'],
                "vad": out['vad'],
                "p_bins": out['p_bins'],
                "p_bins_now": out['p_bins_now'],
                "p_bins_future": out['p_bins_future'],
            },
            "vap_mc": lambda: {
                "p_now": out['p_now'],
                "p_future": out['p_future'],
                "vad": out['vad'],
                "p_bins": out['p_bins'],
                "p_bins_now": out['p_bins_now'],
                "p_bins_future": out['p_bins_future'],
            },
            "vap_prompt": lambda: {
                "p_now": out['p_now'],
                "p_future": out['p_future'],
                "vad": out['vad']
            },
            "bc": lambda: {
                "p_bc": out['p_bc'],
                "p_bc_detect": out['p_bc_detect']
            },
            "bc_2type": lambda: {
                "p_bc_react": out['p_bc_react'],
                "p_bc_emo": out['p_bc_emo']
            },
            "nod": lambda: {
                "p_bc": out['p_bc'],
                "p_nod_short": out['p_nod_short'],
                "p_nod_long": out['p_nod_long'],
                "p_nod_long_p": out['p_nod_long_p']
            },
            "nod_para": lambda: {
                "p_nod": out["p_nod"],
                "nod_repetitions": out["nod_repetitions"],
                "nod_repetitions_pred": out["nod_repetitions_pred"],
                "nod_range": out["nod_range"],
                "nod_speed": out["nod_speed"],
                "nod_swing_up": out["nod_swing_up"],
                "nod_swing_up_pred": out["nod_swing_up_pred"],
            },
        }

        # Get mode-specific outputs
        if self.mode in mode_outputs:
            _out = mode_outputs[self.mode]()
            # The p_bins* entries are only reported when return_p_bins is set
            if not self.return_p_bins and self.mode in ("vap", "vap_mc"):
                for _k in ("p_bins", "p_bins_now", "p_bins_future"):
                    _out.pop(_k, None)
            result_dict.update(_out)

        self.result_dict_queue.put(result_dict)

        time_process = time.time() - time_start
        self.list_process_time_context.append(time_process)

        # Performance monitoring: once more than CALC_PROCESS_TIME_INTERVAL
        # timings have been collected, print their average and the achieved
        # processing rate, then start a new measurement window
        if len(self.list_process_time_context) > self.CALC_PROCESS_TIME_INTERVAL:
            ave_proc_time = np.mean(self.list_process_time_context)
            num_process_frame = len(self.list_process_time_context) / (time.time() - self.last_interval_time)
            self.last_interval_time = time.time()

            perf_message = f'[{self.mode}] Average processing time: {ave_proc_time:.5f} [sec], #process/sec: {num_process_frame:.3f}'
            if self.encoder_type == "mimi":
                perf_message += f', chunk_samples: {self.audio_frame_size}'
            print(perf_message)
            self.list_process_time_context.clear()

        self.process_time_abs = time.time()

    # Keep only the trailing padding samples for the next frame (copied, so
    # the kept samples are a new array rather than a slice of the full buffer)
    if self.frame_contxt_padding > 0:
        self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
        self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
    else:
        self.current_x1_audio = np.empty(0, dtype=np.float32)
        self.current_x2_audio = np.empty(0, dtype=np.float32)

reset_runtime_state()

Reset the internal audio buffers and cache states.

Source code in src/maai/model.py
334
335
336
337
338
339
340
341
342
343
344
345
346
def reset_runtime_state(self):
    """Clear the audio buffers, stored encodings and cache, and reset the encoders.

    Also re-arms the flag that makes ``process`` skip the first encoder
    output; the flag is set only when the encoder type is "mimi" and the
    ONNX backend is enabled. Each of ``encoder1`` / ``encoder2`` that exists
    on the model and provides ``reset_streaming_state`` gets that method
    called.
    """
    self.current_x1_audio, self.current_x2_audio = [], []
    self.e1_full, self.e2_full = [], []
    self.vap_cache = None

    uses_onnx_mimi = self.encoder_type == "mimi" and self._use_mimi_onnx
    self._skip_first_encoder_output = bool(uses_onnx_mimi)

    for attr_name in ("encoder1", "encoder2"):
        enc = getattr(self.vap, attr_name, None)
        # Skip missing encoders and encoders without the reset method.
        if enc is None or not hasattr(enc, "reset_streaming_state"):
            continue
        enc.reset_streaming_state()

set_prompt_ch1(prompt)

Set the prompt text for speaker 1. This method is only available for the 'vap_prompt' mode.

Parameters:

Name Type Description Default
prompt str

The prompt text for speaker 1.

required
Source code in src/maai/model.py
662
663
664
665
666
667
668
669
670
671
672
673
def set_prompt_ch1(self, prompt: str):
    """
    Forward a prompt text for speaker 1 to the underlying model.

    Only usable when the instance was created with mode 'vap_prompt'.

    Args:
        prompt (str): The prompt text for speaker 1.

    Raises:
        ValueError: If the current mode is not 'vap_prompt'.
    """
    prompt_mode_active = self.mode == "vap_prompt"
    if not prompt_mode_active:
        raise ValueError("This method is only available for the 'vap_prompt' mode.")

    # The model receives the device string alongside the prompt text.
    self.vap.set_prompt_ch1(prompt, self.device)

set_prompt_ch2(prompt)

Set the prompt text for speaker 2. This method is only available for the 'vap_prompt' mode.

Parameters:

Name Type Description Default
prompt str

The prompt text for speaker 2.

required
Source code in src/maai/model.py
675
676
677
678
679
680
681
682
683
684
685
686
def set_prompt_ch2(self, prompt: str):
    """
    Forward a prompt text for speaker 2 to the underlying model.

    Only usable when the instance was created with mode 'vap_prompt'.

    Args:
        prompt (str): The prompt text for speaker 2.

    Raises:
        ValueError: If the current mode is not 'vap_prompt'.
    """
    prompt_mode_active = self.mode == "vap_prompt"
    if not prompt_mode_active:
        raise ValueError("This method is only available for the 'vap_prompt' mode.")

    # The model receives the device string alongside the prompt text.
    self.vap.set_prompt_ch2(prompt, self.device)

start()

Start the background audio fetching and processing thread.

Source code in src/maai/model.py
394
395
396
397
398
399
400
401
402
403
404
405
406
def start(self):
    """Start the background audio fetching and processing thread.

    Resets the runtime state, starts both audio sources, clears the stop
    flag, empties the subscribed queues and then launches ``self.worker``
    on a daemon thread.
    """
    self.reset_runtime_state()

    for mic in (self.mic1, self.mic2):
        mic.start()

    self._stop_event.clear()

    # Empty the queues before launching the thread, so that no stale data
    # reaches the worker.
    for mic_queue in (self._mic1_queue, self._mic2_queue):
        mic_queue.queue.clear()

    worker_thread = threading.Thread(target=self.worker, daemon=True)
    self._worker_thread = worker_thread
    worker_thread.start()

stop(wait=True, timeout=2.0)

Safely stop the background processing thread. Parameters: wait (bool) – if True, wait for the thread to finish; timeout (float) – maximum number of seconds to wait when joining.

Source code in src/maai/model.py
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
def stop(self, wait: bool = True, timeout: float = 2.0):
    """Safely stop the background processing thread.

    Sets the stop flag, pushes a ``None`` sentinel into each audio queue,
    optionally joins the worker thread, then clears the queues and resets
    the runtime state.

    Args:
        wait (bool): If True, wait for the thread to finish.
        timeout (float): Max seconds to wait when joining.
    """
    self._stop_event.set()

    # Sentinels unblock a worker that is waiting on a blocking get.
    try:
        for mic_queue in (self._mic1_queue, self._mic2_queue):
            mic_queue.put(None)
    except Exception:
        pass

    if wait:
        thread = self._worker_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=timeout)

    # Queue cleanup is best-effort; failures are deliberately ignored.
    try:
        for mic_queue in (self._mic1_queue, self._mic2_queue):
            mic_queue.queue.clear()
    except Exception:
        pass

    self.reset_runtime_state()

worker()

Background loop to fetch audio from queues and run inference.

Source code in src/maai/model.py
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
def worker(self):
    """Background loop to fetch audio from queues and run inference.

    Runs until the stop event is set or either audio source returns
    ``None``. Each pair of chunks is handed to ``self.process``.
    """
    # Start from empty queues so the loop does not begin on data that was
    # buffered before it was launched.
    self._mic1_queue.queue.clear()
    self._mic2_queue.queue.clear()

    while True:
        if self._stop_event.is_set():
            break

        chunk1 = self.mic1.get_audio_data(self._mic1_queue)
        chunk2 = self.mic2.get_audio_data(self._mic2_queue)

        stream_ended = chunk1 is None or chunk2 is None
        if self._stop_event.is_set() or stream_ended:
            break

        self.process(chunk1, chunk2)

        # Drop the backlog of any channel whose queue has grown too large.
        overflow_checks = (
            (
                self._mic1_queue,
                "[Warning] Audio queue (channel 1) overflow detected. Clearing audio queues.",
            ),
            (
                self._mic2_queue,
                "[Warning] Audio queue (channel 2) overflow detected. Clearing audio queues.",
            ),
        )
        for audio_queue, warning in overflow_checks:
            if audio_queue.qsize() > 100:
                audio_queue.queue.clear()
                print(warning)

MaaiMultiple

Run several Maai models in parallel that share a single audio encoder.

This is useful when you want to combine, for example, the turn-taking (vap) model with the backchannel (bc) model and the nodding (nod) model on the same input audio. A naive approach would build one Maai instance per model and run the (relatively expensive) audio encoder once for each instance. MaaiMultiple instead encodes the audio once per frame and feeds the encoded features into every sub-model.

All sub-models must share the same encoder configuration: model_type, frame_rate, context_len_sec, device and the mimi_* parameters. Per-model differences allowed in configs are mode, lang, local_model, return_p_bins and an optional label used as the result key.

Each call to get_result() returns a single dict whose top level contains the shared fields t, x1 and x2, plus one nested dict per sub-model keyed by its label (or by its mode if no label is given).

Source code in src/maai/model.py
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
class MaaiMultiple:
    """
    Run several Maai models in parallel that share a single audio encoder.

    This is useful when you want to combine, for example, the turn-taking
    (``vap``) model with the backchannel (``bc``) model and the nodding
    (``nod``) model on the same input audio. A naive approach would build
    one ``Maai`` instance per model and run the (relatively expensive) audio
    encoder once for each instance. ``MaaiMultiple`` instead encodes the
    audio once per frame and feeds the encoded features into every sub-model.

    All sub-models must share the same encoder configuration:
    ``model_type``, ``frame_rate``, ``context_len_sec``, ``device`` and the
    ``mimi_*`` parameters. Per-model differences allowed in ``configs`` are
    ``mode``, ``lang``, ``local_model``, ``return_p_bins`` and an optional
    ``label`` used as the result key.

    Each call to :meth:`get_result` returns a single ``dict`` whose top level
    contains shared fields ``t``, ``x1``, ``x2`` plus one nested ``dict``
    per sub-model keyed by its label (or its mode if no label is given).
    """

    CALC_PROCESS_TIME_INTERVAL = 100

    # Modes whose ``encode_audio`` swaps audio1/audio2 before model forward.
    _SWAP_MODES = {"bc", "bc_2type", "nod", "nod_para"}

    def __init__(
        self,
        configs: list,
        audio_ch1: Base,
        audio_ch2: Base,
        frame_rate: float = 10,
        context_len_sec: int = 20,
        device: str = "cpu",
        cpc_model: str = os.path.expanduser("~/.cache/cpc/60k_epoch4-d0f474de.pt"),
        model_type: str = "normal",
        mimi_model_name: str = "kyutai/mimi",
        use_mimi_onnx: bool = True,
        mimi_onnx_precision: str = "fp32",
        mimi_onnx_fp32_path: str | None = None,
        mimi_onnx_fp32_meta_path: str | None = None,
        mimi_onnx_int8_path: str | None = None,
        mimi_onnx_int8_meta_path: str | None = None,
        mimi_local_onnx_fp32_path: str | None = None,
        mimi_local_onnx_fp32_meta_path: str | None = None,
        mimi_local_onnx_int8_path: str | None = None,
        mimi_local_onnx_int8_meta_path: str | None = None,
        mimi_onnx_cpu_intra_threads: int | None = None,
        mimi_onnx_cpu_inter_threads: int | None = None,
        cache_dir: str = None,
        force_download: bool = False,
        use_kv_cache: bool = True,
    ):
        if not configs:
            raise ValueError("MaaiMultiple requires at least one model config.")

        shared_kwargs = dict(
            audio_ch1=audio_ch1,
            audio_ch2=audio_ch2,
            frame_rate=frame_rate,
            context_len_sec=context_len_sec,
            device=device,
            cpc_model=cpc_model,
            model_type=model_type,
            mimi_model_name=mimi_model_name,
            use_mimi_onnx=use_mimi_onnx,
            mimi_onnx_precision=mimi_onnx_precision,
            mimi_onnx_fp32_path=mimi_onnx_fp32_path,
            mimi_onnx_fp32_meta_path=mimi_onnx_fp32_meta_path,
            mimi_onnx_int8_path=mimi_onnx_int8_path,
            mimi_onnx_int8_meta_path=mimi_onnx_int8_meta_path,
            mimi_local_onnx_fp32_path=mimi_local_onnx_fp32_path,
            mimi_local_onnx_fp32_meta_path=mimi_local_onnx_fp32_meta_path,
            mimi_local_onnx_int8_path=mimi_local_onnx_int8_path,
            mimi_local_onnx_int8_meta_path=mimi_local_onnx_int8_meta_path,
            mimi_onnx_cpu_intra_threads=mimi_onnx_cpu_intra_threads,
            mimi_onnx_cpu_inter_threads=mimi_onnx_cpu_inter_threads,
            cache_dir=cache_dir,
            force_download=force_download,
            use_kv_cache=use_kv_cache,
        )

        self.sub_maais: list[Maai] = []
        self.labels: list[str] = []
        seen_labels: set[str] = set()
        for cfg in configs:
            if "mode" not in cfg or "lang" not in cfg:
                raise ValueError("Each entry of `configs` must contain 'mode' and 'lang'.")
            label = cfg.get("label", cfg["mode"])
            if label in seen_labels:
                raise ValueError(
                    f"Duplicate label '{label}'. Provide a unique 'label' field in each config."
                )
            seen_labels.add(label)
            self.labels.append(label)
            sub = Maai(
                mode=cfg["mode"],
                lang=cfg["lang"],
                local_model=cfg.get("local_model"),
                return_p_bins=cfg.get("return_p_bins", False),
                **shared_kwargs,
            )
            self.sub_maais.append(sub)

        primary = self.sub_maais[0]

        # Mic configuration. Each Maai sub-instance subscribes a queue from
        # the input source in __init__; we keep only the primary's queues and
        # detach the rest so the audio source does not push frames into queues
        # that nobody drains.
        self.mic1 = audio_ch1
        self.mic2 = audio_ch2
        self._mic1_queue = primary._mic1_queue
        self._mic2_queue = primary._mic2_queue
        for sub in self.sub_maais[1:]:
            self._unsubscribe(self.mic1, sub._mic1_queue)
            self._unsubscribe(self.mic2, sub._mic2_queue)

        # Shared frame parameters (validated via primary; all sub-Maais use
        # the same encoder configuration so these match across instances).
        self.device = primary.device
        self.encoder_type = primary.encoder_type
        self._use_mimi_onnx = primary._use_mimi_onnx
        self.frame_rate = float(frame_rate)
        self.audio_contenxt_lim_sec = context_len_sec
        self.audio_context_len = primary.audio_context_len
        self.sampling_rate = primary.sampling_rate
        self.frame_contxt_padding = primary.frame_contxt_padding
        self.audio_frame_size = primary.audio_frame_size
        self.use_kv_cache = bool(use_kv_cache)

        # Shared audio buffers and (when KV cache is disabled) shared encoded
        # context. These are populated by the worker on each frame.
        self.current_x1_audio = []
        self.current_x2_audio = []
        self.eA_full: list[torch.Tensor] = []
        self.eB_full: list[torch.Tensor] = []

        # One result queue, populated with a combined dict per frame.
        self.result_dict_queue: queue.Queue = queue.Queue()

        # Performance monitoring.
        self.list_process_time_context: list[float] = []
        self.last_interval_time = time.time()
        self.process_time_abs = -1

        # Threading.
        self._stop_event = threading.Event()
        self._worker_thread: threading.Thread | None = None

        self.reset_runtime_state()

    @staticmethod
    def _unsubscribe(source: Base, q):
        try:
            with source._lock:
                if q in source._subscriber_queues:
                    source._subscriber_queues.remove(q)
        except Exception:
            pass

    def reset_runtime_state(self):
        self.current_x1_audio = []
        self.current_x2_audio = []
        self.eA_full = []
        self.eB_full = []
        self._skip_first_encoder_output = bool(
            self.encoder_type == "mimi" and self._use_mimi_onnx
        )

        primary_vap = self.sub_maais[0].vap
        for encoder_name in ["encoder1", "encoder2"]:
            encoder = getattr(primary_vap, encoder_name, None)
            if encoder is not None and hasattr(encoder, "reset_streaming_state"):
                encoder.reset_streaming_state()

        for sub in self.sub_maais:
            sub.vap_cache = None
            sub.e1_full = []
            sub.e2_full = []

    def worker(self):
        self._mic1_queue.queue.clear()
        self._mic2_queue.queue.clear()

        while not self._stop_event.is_set():
            x1 = self.mic1.get_audio_data(self._mic1_queue)
            x2 = self.mic2.get_audio_data(self._mic2_queue)

            if self._stop_event.is_set() or x1 is None or x2 is None:
                break

            self.process(x1, x2)

            if self._mic1_queue.qsize() > 100:
                self._mic1_queue.queue.clear()
                print("[Warning] Audio queue (channel 1) overflow detected. Clearing audio queues.")
            if self._mic2_queue.qsize() > 100:
                self._mic2_queue.queue.clear()
                print("[Warning] Audio queue (channel 2) overflow detected. Clearing audio queues.")

    def start(self):
        self.reset_runtime_state()

        self.mic1.start()
        self.mic2.start()
        self._stop_event.clear()
        # Queue を空にしてからスレッド起動(スレッドに古いデータが届かないよう先にクリア)
        self._mic1_queue.queue.clear()
        self._mic2_queue.queue.clear()
        self._worker_thread = threading.Thread(target=self.worker, daemon=True)
        self._worker_thread.start()

    def stop(self, wait: bool = True, timeout: float = 2.0):
        self._stop_event.set()
        try:
            self._mic1_queue.put(None)
            self._mic2_queue.put(None)
        except Exception:
            pass
        if wait and self._worker_thread is not None and self._worker_thread.is_alive():
            self._worker_thread.join(timeout=timeout)

        try:
            self._mic1_queue.queue.clear()
            self._mic2_queue.queue.clear()
        except Exception:
            pass

        self.reset_runtime_state()

    def _trim_audio_buffers(self):
        if self.frame_contxt_padding > 0:
            self.current_x1_audio = self.current_x1_audio[-self.frame_contxt_padding:].copy()
            self.current_x2_audio = self.current_x2_audio[-self.frame_contxt_padding:].copy()
        else:
            self.current_x1_audio = np.empty(0, dtype=np.float32)
            self.current_x2_audio = np.empty(0, dtype=np.float32)

    @staticmethod
    def _trim_kv_cache(cache: dict, ctx_len: int) -> dict:
        new_cache = {}
        for key, (k_list, v_list) in cache.items():
            new_k_list = []
            new_v_list = []
            for t in k_list:
                if isinstance(t, torch.Tensor) and t.dim() >= 3:
                    new_k_list.append(t[..., -(ctx_len - 1):, :])
                else:
                    new_k_list.append(t)
            for t in v_list:
                if isinstance(t, torch.Tensor) and t.dim() >= 3:
                    new_v_list.append(t[..., -(ctx_len - 1):, :])
                else:
                    new_v_list.append(t)
            new_cache[key] = (new_k_list, new_v_list)
        return new_cache

    def process(self, x1, x2):
        """Encode one chunk of two-channel audio once and run every sub-model on it.

        The new samples are appended to the per-channel buffers. Once the
        channel-1 buffer holds at least ``audio_frame_size`` samples, both
        buffers are encoded with the primary sub-model's encoders, each
        sub-model is run on the shared features, and one combined result dict
        is put on ``self.result_dict_queue``.

        Args:
            x1: New audio samples for channel 1; concatenated onto a numpy
                buffer (assumed to be a 1-D float array - TODO confirm
                against the audio source).
            x2: New audio samples for channel 2, same assumptions as ``x1``.

        Returns:
            None. No result is queued when there is not enough audio yet,
            when the encoder produced no features, or for the first encoder
            output when ``_skip_first_encoder_output`` is set.
        """
        time_start = time.time()

        # Seed empty buffers with zeros of length ``frame_contxt_padding``.
        if len(self.current_x1_audio) == 0:
            self.current_x1_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)
        if len(self.current_x2_audio) == 0:
            self.current_x2_audio = np.zeros(self.frame_contxt_padding, dtype=np.float32)

        self.current_x1_audio = np.concatenate([self.current_x1_audio, x1])
        self.current_x2_audio = np.concatenate([self.current_x2_audio, x2])

        # Not enough audio for a frame yet; keep accumulating (buffers are
        # not trimmed on this path).
        # NOTE(review): only the channel-1 buffer length is checked here.
        if len(self.current_x1_audio) < self.audio_frame_size:
            return

        x1_proc = self.current_x1_audio
        x2_proc = self.current_x2_audio
        # Audio reported in the result: the buffer without its leading
        # ``frame_contxt_padding`` samples.
        x1_dist = x1_proc[self.frame_contxt_padding:]
        x2_dist = x2_proc[self.frame_contxt_padding:]

        with torch.inference_mode():
            # Two leading singleton dimensions are added to each buffer.
            x1_t = torch.from_numpy(x1_proc).float().unsqueeze(0).unsqueeze(0)
            x2_t = torch.from_numpy(x2_proc).float().unsqueeze(0).unsqueeze(0)
            if self.device != "cpu":
                x1_t = x1_t.to(self.device, non_blocking=True)
                x2_t = x2_t.to(self.device, non_blocking=True)

            # Shared encoding step. We extract the shared features (before model-specific downsample)
            # using the primary model's encoder. The state/KV cache is maintained here.
            primary_vap = self.sub_maais[0].vap

            # For "mimi", forward_shared also returns a sample count that is
            # later passed to forward_specific.
            if self.encoder_type == "mimi":
                eA_shared, input_num_samples_A = primary_vap.encoder1.forward_shared(x1_t)
                eB_shared, input_num_samples_B = primary_vap.encoder2.forward_shared(x2_t)
            else:
                eA_shared = primary_vap.encoder1.forward_shared(x1_t)
                eB_shared = primary_vap.encoder2.forward_shared(x2_t)

            if eA_shared.shape[1] == 0 or eB_shared.shape[1] == 0:
                self._trim_audio_buffers()
                print("[Warning] No audio features extracted. Skipping this frame.")
                return

            # Skip the first Mimi encoder output to avoid the startup-only
            # mismatch between ONNX and PyTorch cache warmup behavior, just
            # like Maai.process does.
            if self._skip_first_encoder_output:
                self._skip_first_encoder_output = False
                self.process_time_abs = time.time()
                self._trim_audio_buffers()
                return

            # If KV cache is disabled, maintain a shared rolling context of
            # encoded features and feed each sub-model with the full window.
            if not self.use_kv_cache:
                self.eA_full.append(eA_shared)
                self.eB_full.append(eB_shared)
                # Cap each list at ``audio_context_len`` entries.
                if len(self.eA_full) > self.audio_context_len:
                    self.eA_full.pop(0)
                if len(self.eB_full) > self.audio_context_len:
                    self.eB_full.pop(0)
                eA_in = torch.cat(self.eA_full, dim=1)
                eB_in = torch.cat(self.eB_full, dim=1)
            else:
                eA_in = eA_shared
                eB_in = eB_shared

            # Fields shared by all sub-models; per-model outputs are added
            # below under each label.
            results_combined: dict = {
                "t": time.time(),
                "x1": x1_dist.copy(),
                "x2": x2_dist.copy(),
            }

            for label, sub in zip(self.labels, self.sub_maais):
                # Reproduce the per-mode swap that lives in each model's
                # encode_audio: bc/bc_2type/nod/nod_para want the swapped
                # (user, system) ordering, the others want the natural one.
                if sub.mode in self._SWAP_MODES:
                    e1_shared, e2_shared = eB_in, eA_in
                else:
                    e1_shared, e2_shared = eA_in, eB_in

                # Apply the model-specific downsampling layer
                if self.encoder_type == "mimi":
                    # Keep the sample counts aligned with the (possibly
                    # swapped) channels.
                    if sub.mode in self._SWAP_MODES:
                        n1, n2 = input_num_samples_B, input_num_samples_A
                    else:
                        n1, n2 = input_num_samples_A, input_num_samples_B
                    e1 = sub.vap.encoder1.forward_specific(e1_shared, input_num_samples=n1)
                    e2 = sub.vap.encoder2.forward_specific(e2_shared, input_num_samples=n2)
                else:
                    e1 = sub.vap.encoder1.forward_specific(e1_shared)
                    e2 = sub.vap.encoder2.forward_specific(e2_shared)

                # Apply each sub-model's decrease_dimension here (most models
                # apply it inside encode_audio). nod_para applies projections
                # internally inside its forward, so leave it alone.
                if sub.mode != "nod_para" and hasattr(sub.vap, "decrease_dimension"):
                    e1 = sub.vap.decrease_dimension(e1)
                    e2 = sub.vap.decrease_dimension(e2)

                if self.use_kv_cache:
                    # Each sub-model keeps its own cache, trimmed after
                    # every step.
                    out, sub.vap_cache = sub.vap.forward(e1, e2, cache=sub.vap_cache)
                    if sub.vap_cache is not None:
                        sub.vap_cache = self._trim_kv_cache(sub.vap_cache, self.audio_context_len)
                else:
                    out, _ = sub.vap.forward(e1, e2, cache=None)

                results_combined[label] = self._extract_outputs(
                    sub.mode, out, sub.return_p_bins
                )

            self.result_dict_queue.put(results_combined)

            time_process = time.time() - time_start
            self.list_process_time_context.append(time_process)

            # Print timing statistics once more than
            # CALC_PROCESS_TIME_INTERVAL samples have been collected, then
            # start a new measurement window.
            if len(self.list_process_time_context) > self.CALC_PROCESS_TIME_INTERVAL:
                ave_proc_time = np.mean(self.list_process_time_context)
                num_process_frame = (
                    len(self.list_process_time_context)
                    / (time.time() - self.last_interval_time)
                )
                self.last_interval_time = time.time()

                modes = ",".join(self.labels)
                msg = (
                    f"[multi:{modes}] Average processing time: {ave_proc_time:.5f} [sec], "
                    f"#process/sec: {num_process_frame:.3f}"
                )
                if self.encoder_type == "mimi":
                    msg += f", chunk_samples: {self.audio_frame_size}"
                print(msg)
                self.list_process_time_context.clear()

            self.process_time_abs = time.time()

        # Keep only the trailing padding for the next call.
        self._trim_audio_buffers()

    @staticmethod
    def _extract_outputs(mode: str, out: dict, return_p_bins: bool) -> dict:
        if mode in ("vap", "vap_mc"):
            d = {
                "p_now": out["p_now"],
                "p_future": out["p_future"],
                "vad": out["vad"],
                "p_bins": out["p_bins"],
                "p_bins_now": out["p_bins_now"],
                "p_bins_future": out["p_bins_future"],
            }
            if not return_p_bins:
                for k in ("p_bins", "p_bins_now", "p_bins_future"):
                    d.pop(k, None)
            return d
        if mode == "vap_prompt":
            return {
                "p_now": out["p_now"],
                "p_future": out["p_future"],
                "vad": out["vad"],
            }
        if mode == "bc":
            return {
                "p_bc": out["p_bc"],
                "p_bc_detect": out["p_bc_detect"],
            }
        if mode == "bc_2type":
            return {
                "p_bc_react": out["p_bc_react"],
                "p_bc_emo": out["p_bc_emo"],
            }
        if mode == "nod":
            return {
                "p_bc": out["p_bc"],
                "p_nod_short": out["p_nod_short"],
                "p_nod_long": out["p_nod_long"],
                "p_nod_long_p": out["p_nod_long_p"],
            }
        if mode == "nod_para":
            return {
                "p_nod": out["p_nod"],
                "nod_repetitions": out["nod_repetitions"],
                "nod_repetitions_pred": out["nod_repetitions_pred"],
                "nod_range": out["nod_range"],
                "nod_speed": out["nod_speed"],
                "nod_swing_up": out["nod_swing_up"],
                "nod_swing_up_pred": out["nod_swing_up_pred"],
            }
        return {}

    def get_result(self):
        return self.result_dict_queue.get()

    def get_sub_maai(self, label: str) -> Maai:
        """Return the underlying ``Maai`` instance registered under ``label``."""
        for lbl, sub in zip(self.labels, self.sub_maais):
            if lbl == label:
                return sub
        raise KeyError(f"No sub-model with label '{label}'. Available: {self.labels}")

    def set_prompt_ch1(self, prompt: str, label: str | None = None):
        """Set channel-1 prompt on every ``vap_prompt`` sub-model (or only on ``label``)."""
        applied = False
        for lbl, sub in zip(self.labels, self.sub_maais):
            if sub.mode == "vap_prompt" and (label is None or lbl == label):
                sub.set_prompt_ch1(prompt)
                applied = True
        if label is not None and not applied:
            raise ValueError(f"No 'vap_prompt' sub-model found for label '{label}'.")

    def set_prompt_ch2(self, prompt: str, label: str | None = None):
        """Set channel-2 prompt on every ``vap_prompt`` sub-model (or only on ``label``)."""
        applied = False
        for lbl, sub in zip(self.labels, self.sub_maais):
            if sub.mode == "vap_prompt" and (label is None or lbl == label):
                sub.set_prompt_ch2(prompt)
                applied = True
        if label is not None and not applied:
            raise ValueError(f"No 'vap_prompt' sub-model found for label '{label}'.")

get_sub_maai(label)

Return the underlying Maai instance registered under label.

Source code in src/maai/model.py
1139
1140
1141
1142
1143
1144
def get_sub_maai(self, label: str) -> Maai:
    """Look up the ``Maai`` sub-model that was registered as ``label``.

    Raises:
        KeyError: If no sub-model carries that label.
    """
    not_found = object()
    match = next(
        (sub for lbl, sub in zip(self.labels, self.sub_maais) if lbl == label),
        not_found,
    )
    if match is not_found:
        raise KeyError(f"No sub-model with label '{label}'. Available: {self.labels}")
    return match

set_prompt_ch1(prompt, label=None)

Set channel-1 prompt on every vap_prompt sub-model (or only on label).

Source code in src/maai/model.py
1146
1147
1148
1149
1150
1151
1152
1153
1154
def set_prompt_ch1(self, prompt: str, label: str | None = None):
    """Apply a channel-1 prompt to the ``vap_prompt`` sub-models.

    With ``label=None`` every ``vap_prompt`` sub-model receives the prompt;
    otherwise only the one registered under ``label``.

    Raises:
        ValueError: If ``label`` is given and matches no ``vap_prompt``
            sub-model.
    """
    updated = 0
    for lbl, sub in zip(self.labels, self.sub_maais):
        if sub.mode != "vap_prompt":
            continue
        if label is not None and not lbl == label:
            continue
        sub.set_prompt_ch1(prompt)
        updated += 1

    if label is not None and updated == 0:
        raise ValueError(f"No 'vap_prompt' sub-model found for label '{label}'.")

set_prompt_ch2(prompt, label=None)

Set channel-2 prompt on every vap_prompt sub-model (or only on label).

Source code in src/maai/model.py
1156
1157
1158
1159
1160
1161
1162
1163
1164
def set_prompt_ch2(self, prompt: str, label: str | None = None):
    """Apply a channel-2 prompt to the ``vap_prompt`` sub-models.

    With ``label=None`` every ``vap_prompt`` sub-model receives the prompt;
    otherwise only the one registered under ``label``.

    Raises:
        ValueError: If ``label`` is given and matches no ``vap_prompt``
            sub-model.
    """
    updated = 0
    for lbl, sub in zip(self.labels, self.sub_maais):
        if sub.mode != "vap_prompt":
            continue
        if label is not None and not lbl == label:
            continue
        sub.set_prompt_ch2(prompt)
        updated += 1

    if label is not None and updated == 0:
        raise ValueError(f"No 'vap_prompt' sub-model found for label '{label}'.")