From 9f3114c8dca108be5bbe120ca547241786e5057c Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:26:10 +0530 Subject: [PATCH 01/14] adding multilingual support --- .DS_Store | Bin 0 -> 6148 bytes configs/config.yml | 24 +- configs/languages.json | 36 +++ convert_to_pb.py | 35 ++- models/.DS_Store | Bin 0 -> 6148 bytes multilingual_dataset.py | 124 +++++++++ predict_by_pb.py | 94 +++++-- requirements.txt | 20 +- saved_models/.DS_Store | Bin 0 -> 6148 bytes saved_models/lang14/.DS_Store | Bin 0 -> 6148 bytes saved_models/lang14/pb/.DS_Store | Bin 0 -> 6148 bytes saved_models/lang14/pb/2/.DS_Store | Bin 0 -> 6148 bytes test_tf.py | 13 + train.py | 399 +++++++++++++---------------- train_multilingual.py | 147 +++++++++++ 15 files changed, 627 insertions(+), 265 deletions(-) create mode 100644 .DS_Store create mode 100644 configs/languages.json create mode 100644 models/.DS_Store create mode 100644 multilingual_dataset.py create mode 100644 saved_models/.DS_Store create mode 100644 saved_models/lang14/.DS_Store create mode 100644 saved_models/lang14/pb/.DS_Store create mode 100644 saved_models/lang14/pb/2/.DS_Store create mode 100644 test_tf.py create mode 100644 train_multilingual.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..69826a8ffb3bc1bf656d0348be3c138467bfb61f GIT binary patch literal 6148 zcmeHK!AiqG5S?vnn^1)w6nb3nTCmkx6fdFHA26Z^m6(vC!8BW%wmFnS&iX@siQnVQ z?pCN(yot0sF#C3AXArTQJKn(yaRKkLd!#6_vq$^U;mJ%Y9&xj!nvUC{4 zB3SV@hX2R_eY-LYKtKq6SpI%*!7z@}akKf(3dPb^xnfnUs&(%^iOe1Q<4M}}N4K=P z5+V*3b3eF_`k7bTJr~K?kCOgK21H>WUG8q8BotX!Op-8_v98*%Y|Hj)_1Ua*+-dMu z`()nWv(v+7gSU@5^SN#9?H`<7de8A75ig2JhX0(D9gPck#mQ1mQ+JfaB6&nR#cV?g zK1?y+5WF?a{m1;i@UY;^pkcr;uz3vV^V6zsUg{=e!+>Gn7Y68j;HZR-#$2I1I&hFJ z03ux@rQn=;35www9gVp{+(BU~6j6mT-C{5mj&`p79F4g`6%I@{A56cQ=?;a-x8w7< zbO+`rG^1g_FtE%(Q7vm!|1W;-|CfWz$}nIUSStotsps~(n3AroE6Gu<<){y+B;=PX m{0zZ~EXC+crMQkN1?`+Bh>pfwAzDz(ML^PE2E)LgGVlRlxQ^EV literal 0 HcmV?d00001 diff --git a/configs/config.yml b/configs/config.yml index 9004160..37c69da 100644 --- a/configs/config.yml +++ b/configs/config.yml @@ -8,6 +8,8 @@ speech_config: normalize_signal: True normalize_feature: True normalize_per_feature: False + use_fma: True + use_neon: False model_config: name: acrnn @@ -16,20 +18,23 @@ model_config: kernel_size: [[11,5],[11,5],[11,5]] rnn_cell: 256 seq_mask: True + num_languages: 100 dataset_config: vocabulary: vocab/vocab.txt data_path: ./data/wavs/ - corpus_name: ./data/demo_txt/demo + corpus_name: ./data/multilingual/ file_nums: 1 max_audio_length: 2000 shuffle_size: 1200 data_length: None suffix: .txt - load_type: txt + load_type: multilingual train: train dev: dev test: test + languages_file: configs/languages.json + max_samples_per_language: 10000 optimizer_config: init_steps: 0 @@ -38,12 +43,15 @@ optimizer_config: beta1: 0.9 beta2: 0.999 epsilon: 1e-9 + use_mixed_precision: True running_config: - prefetch: False - load_weights: ./saved_weights/20230228-084356/last/model + prefetch: True + load_weights: ./saved_weights/multilingual/last/model num_epochs: 100 - batch_size: 1 - train_steps: 50 - dev_steps: 10 - test_steps: 10 \ No newline at end of file + batch_size: 32 + train_steps: 1000 + dev_steps: 100 + test_steps: 100 + save_interval: 5 + eval_interval: 1 \ No newline at end of file diff --git a/configs/languages.json b/configs/languages.json new file mode 100644 index 0000000..78986ba --- /dev/null +++ b/configs/languages.json @@ -0,0 +1,36 @@ 
+{ + "supported_languages": [ + "en", "es", "fr", "de", "it", "pt", "nl", "pl", "ru", "uk", + "ar", "hi", "bn", "ta", "te", "mr", "ur", "fa", "tr", "he", + "th", "vi", "id", "ms", "fil", "ja", "ko", "zh", "yue", "sw", + "af", "am", "az", "be", "bg", "ca", "cs", "da", "el", "et", + "eu", "fi", "ga", "gl", "gu", "ha", "hr", "hu", "hy", "ig", + "is", "ka", "kk", "km", "kn", "ky", "lb", "lg", "lt", "lv", + "mk", "ml", "mn", "mt", "my", "ne", "no", "pa", "ps", "ro", + "si", "sk", "sl", "sn", "so", "sq", "sr", "sv", "tg", "tk", + "uz", "wo", "xh", "yi", "yo", "zu", "as", "bho", "doi", "mai", + "or", "raj", "sa", "sd", "cy", "fo", "gd", "kw", "fy", "rm" + ], + "language_names": { + "en": "English", + "es": "Spanish", + "fr": "French", + "de": "German", + "it": "Italian", + "pt": "Portuguese", + "nl": "Dutch", + "pl": "Polish", + "ru": "Russian", + "uk": "Ukrainian", + "ar": "Arabic", + "hi": "Hindi", + "bn": "Bengali", + "ta": "Tamil", + "te": "Telugu", + "mr": "Marathi", + "ur": "Urdu", + "fa": "Persian", + "tr": "Turkish", + "he": "Hebrew" + } +} \ No newline at end of file diff --git a/convert_to_pb.py b/convert_to_pb.py index 39f6acf..17d4705 100644 --- a/convert_to_pb.py +++ b/convert_to_pb.py @@ -20,25 +20,38 @@ vocab = Vocab(vocabulary) -# build model -model=Model(**config.model_config,vocab_size=len(vocab.token_list)) +# Build model +model = Model(**config.model_config, vocab_size=len(vocab.token_list)) model.init_build([None, config.speech_config['num_feature_bins']]) model.load_weights(weights_dir + "last/model") model.add_featurizers(speech_featurizer) - version = 2 -#****convert to pb****** -tf.saved_model.save(model, "saved_models/lang14/pb/" + str(version)) -print('convert to pb model successful') -#****convert to serving****** +# Convert to SavedModel format with signatures +@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)]) +def predict_fn(signal): + output, prob = model.predict_pb(signal) + return {"output_0": output, "output_1": prob} + +# Save model with proper signatures tf.saved_model.save( model, - "./saved_models/lang14/serving/"+str(version), + f"saved_models/lang14/pb/{version}", signatures={ - 'predict_pb': model.predict_pb - } + "serving_default": predict_fn, + "predict_pb": model.predict_pb + } ) +print('Model converted to SavedModel format successfully') -print('convert to serving model successful') +# Save model for TensorFlow Serving +tf.saved_model.save( + model, + f"saved_models/lang14/serving/{version}", + signatures={ + "serving_default": predict_fn, + "predict_pb": model.predict_pb + } +) +print('Model converted for TensorFlow Serving successfully') diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..124a457201c3ca6527eb310ffa9cea25791e661f GIT binary patch literal 6148 zcmeHK%}T>S5Z-NTn^J@x6nb3nTCmkx6fYsx7cim+m70*C!I&*cYmidNSzpK}@p+ut z-5i1iZz6UEcE8#A+0A^A{b7u8XAuk;>oUeHXowt@HG<}~u9^u(R5N~qs-m8H(z9qm)_v$Qp(sb zbuQJ9w%XHfZ@C;Fjr(FaI$rg~ z^5kI97o)@RYSp%O_x4XOrcd#6BHuKT95`39ZLoxQP%1UOxM>p0TF)=R539?4E(?V?hgVQqHVCysJ0I1@cNAY1|kaR_?AEv25p0dM(}`e zoeHQ^xp`u6oep+k;%tM3MxD;MS{ddsE0>QKu2u)TP~nW*8mT7+h=D2tbv^9j`F{?- z%-TnOHHAjR05R~-7~qYGJ8_^WbGCjf56@Zw?H(Em#ucc5fL^%-zySA=o^ooxKpo<2 XgM~(%1??&wkS+p>5bB75Utr)1LeNZA literal 0 HcmV?d00001 diff --git a/multilingual_dataset.py b/multilingual_dataset.py new file mode 100644 index 0000000..2a620e1 --- /dev/null +++ b/multilingual_dataset.py @@ -0,0 +1,124 @@ +import os +import json +import numpy as np +import pandas as pd +from tqdm 
import tqdm +from typing import List, Dict, Optional +from datasets import load_dataset, Audio +from transformers import Wav2Vec2Processor +from featurizers.speech_featurizers import SpeechFeaturizer +from configs.config import Config +from vocab.vocab import Vocab + +class MultilingualDataset: + def __init__( + self, + config: Config, + languages: List[str], + vocab: Vocab, + speech_featurizer: SpeechFeaturizer, + data_type: str = "train", + max_samples_per_language: Optional[int] = None + ): + self.config = config + self.languages = languages + self.vocab = vocab + self.speech_featurizer = speech_featurizer + self.data_type = data_type + self.max_samples_per_language = max_samples_per_language + self.dataset_cache = {} + + # Initialize language mapping + self.language_to_id = {lang: idx for idx, lang in enumerate(languages)} + self.id_to_language = {idx: lang for lang, idx in self.language_to_id.items()} + + # Load FLEURS dataset for multiple languages + self.load_datasets() + + def load_datasets(self): + """Load datasets for all specified languages""" + for lang in tqdm(self.languages, desc="Loading languages"): + try: + # Load FLEURS dataset for the language + dataset = load_dataset("google/fleurs", lang, split=self.data_type) + + # Apply sampling if specified + if self.max_samples_per_language: + dataset = dataset.select(range(min(len(dataset), self.max_samples_per_language))) + + self.dataset_cache[lang] = dataset + print(f"Loaded {len(dataset)} samples for {lang}") + except Exception as e: + print(f"Error loading dataset for {lang}: {str(e)}") + + def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray: + """Process audio data to extract features""" + if sampling_rate != self.config.speech_config['sample_rate']: + # Resample if necessary + import librosa + audio_data = librosa.resample( + audio_data, + orig_sr=sampling_rate, + target_sr=self.config.speech_config['sample_rate'] + ) + + # Extract features using the speech featurizer + features = self.speech_featurizer.extract(audio_data) + return features + + def get_batch_generator(self, batch_size: int): + """Generate batches of data""" + while True: + for lang in self.languages: + dataset = self.dataset_cache[lang] + + for i in range(0, len(dataset), batch_size): + batch_data = dataset[i:i + batch_size] + + features_list = [] + labels = [] + + for item in batch_data: + # Process audio + audio_data = item['audio']['array'] + sampling_rate = item['audio']['sampling_rate'] + features = self.prepare_audio(audio_data, sampling_rate) + features_list.append(features) + + # Get language label + labels.append(self.language_to_id[lang]) + + # Pad features to same length + max_len = max(feat.shape[0] for feat in features_list) + padded_features = np.zeros((len(features_list), max_len, features_list[0].shape[1])) + + for j, feat in enumerate(features_list): + padded_features[j, :feat.shape[0], :] = feat + + yield { + 'features': padded_features, + 'input_lengths': np.array([len(feat) for feat in features_list]), + 'labels': np.array(labels) + } + + def save_language_mapping(self, file_path: str): + """Save language to ID mapping""" + with open(file_path, 'w') as f: + json.dump({ + 'language_to_id': self.language_to_id, + 'id_to_language': self.id_to_language + }, f, indent=2) + + @classmethod + def load_language_mapping(cls, file_path: str) -> Dict: + """Load language to ID mapping""" + with open(file_path, 'r') as f: + return json.load(f) + + def get_num_languages(self) -> int: + """Get total number of languages""" + 
return len(self.languages) + + def get_language_list(self) -> List[str]: + """Get list of all languages""" + return self.languages.copy() \ No newline at end of file diff --git a/predict_by_pb.py b/predict_by_pb.py index 9be4dc6..1495b27 100644 --- a/predict_by_pb.py +++ b/predict_by_pb.py @@ -1,30 +1,86 @@ -from signal import signal +import os +# Force CPU only +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +# Must import tensorflow after setting environment variables import tensorflow as tf -gpus = tf.config.list_physical_devices('GPU') -tf.config.set_visible_devices(gpus[0:1], 'GPU') +print("TensorFlow version:", tf.__version__) +print("Using CPU only") + from vocab.vocab import Vocab import librosa import numpy as np import sys -import os from tqdm import tqdm from sklearn.metrics import accuracy_score +def load_model(model_path): + try: + # Load model in CPU mode + with tf.device('/CPU:0'): + return tf.saved_model.load(model_path) + except Exception as e: + print(f"Error loading model: {str(e)}") + return None -vocab = Vocab("vocab/vocab.txt") -model = tf.saved_model.load('saved_models/lang14/pb/2/') - - -def predict_wav(wav_path): - signal, _ = librosa.load(wav_path, sr=16000) - output, prob = model.predict_pb(signal) - language = vocab.token_list[output.numpy()] - print(language, prob.numpy()*100) - - return output.numpy(), prob.numpy() - +def predict_wav(wav_path, model, vocab): + try: + # Load and preprocess audio + signal, _ = librosa.load(wav_path, sr=16000) + + # Convert to tensor and ensure CPU operation + with tf.device('/CPU:0'): + signal = tf.convert_to_tensor(signal, dtype=tf.float32) + + # Make prediction + if hasattr(model, 'predict_pb'): + output = model.predict_pb(signal) + else: + output = model(signal) + + if isinstance(output, dict): + pred = output.get("output_0", None) + prob = output.get("output_1", None) + elif isinstance(output, (list, tuple)) and len(output) == 2: + pred, prob = output + else: + print("Unexpected model output format") + return None, None + + # Get prediction + pred_idx = tf.argmax(pred).numpy() + probability = tf.reduce_max(tf.nn.softmax(prob)).numpy() + + language = vocab.token_list[pred_idx] + print(f"Detected language: {language} (confidence: {probability*100:.2f}%)") + + return pred_idx, probability + + except Exception as e: + print(f"Error during prediction: {str(e)}") + return None, None if __name__ == '__main__': - wav_path = sys.argv[1] - predict_wav(wav_path) - + try: + # Initialize vocabulary + vocab = Vocab("vocab/vocab.txt") + + # Load model + print("Loading model...") + model = load_model('saved_models/lang14/pb/2/') + + if model is None: + print("Failed to load model. Exiting.") + sys.exit(1) + + # Make prediction + print("Making prediction...") + predict_wav("test_audios/french.wav", model, vocab) + + except Exception as e: + print(f"Error: {str(e)}") + print("\nTroubleshooting tips:") + print("1. Make sure the model file exists in saved_models/lang14/pb/2/") + print("2. Make sure the vocabulary file exists in vocab/vocab.txt") + print("3. 
Make sure the audio file exists in test_audios/french.wav") diff --git a/requirements.txt b/requirements.txt index d3be841..f2f6e5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ -tensorflow==2.4.1 -tensorflow-gpu==2.4.1 -tensorflow-addons==0.15.0 -matplotlib==3.5.0 -numpy==1.19.5 -scikit-learn==1.0.1 -librosa==0.8.1 -SoundFile==0.10.3.post1 -PyYAML==6.0 \ No newline at end of file +tensorflow-cpu==2.11.0 +numpy==1.23.5 +librosa==0.10.1 +soundfile==0.12.1 +matplotlib==3.7.1 +scikit-learn==1.2.2 +PyYAML==6.0 +tqdm>=4.65.0 +pandas>=2.0.0 +transformers>=4.30.0 +datasets>=2.12.0 diff --git a/saved_models/.DS_Store b/saved_models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..aa05c57011fedb823c2f90152f48959fde48b3e1 GIT binary patch literal 6148 zcmeHK%}T>S5Z-NTn^J@v6nb3nTCmkx6fYsx7cim+m70)JgE3p0)*MP9XMG``#OHBl zcXJ2^youNu*!^bbXE*af_J=XX-BmbX%w~*P&=5H)6@unUSIqh7W99Gv2rrE^ayO7D!kCN#^B}CyALhf&)B$SyeS4o(vTu(c!w$+~4-Sv8OGU|!J@O0A? z>$9VNPYjPon@!u=KR7(UnmorZiG0&Ua^PIauE7G{K`GVr>MfF3CNr=XIYlfXF+dCu z1H{1gF<{ODtFwKjQ^mvpG4KNexIYMJh>pQrquM&4!|OBp8;B^N<68nz7<3Hg8o>j? zbt<4v<>raObvoFEiE|9*8g)A3YGs(mtXw=^xLO_TLWMK#Xr!JPAO^||)b-HD^Zycl znaW3gIfX{V05R~-7~qYuH+G>YbGCjf56@Zw?GYLZ#^tDhfL^);zySA=u5xO>Kpo;7 XgSkeW1??&wkS+p>5bB75Utr)1R!mC6 literal 0 HcmV?d00001 diff --git a/saved_models/lang14/.DS_Store b/saved_models/lang14/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..12f0c235ff07955bbe2306207348bb23f899f07f GIT binary patch literal 6148 zcmeHKQA@)x5Kgw~GKSCxg*^s*9k}TniZ7*`f53`9sLYlQE!JkNoqZUCKIG3PFDlXGv^* ze;x4KYi!2Cu*ha#zdy@+y$@ciH@3E$qAA+q&VN*ep9lFO^Md&eT9-;CVWkJ*RXi;w z&hAL1c@U@5xlV|qDTLf!$7!SrPc70Y)49HJh?ZzgoX&DNJRWvse{izu%H`={uPX;f z!_}%K_Vy3XE+$XObE;kqnH;#*vSYD=H&FCmUi^8QsPqBsDz=ItBnF59Vt^PR25fB* zZK8r0AO?tmZw%o1AfX|;7E6Qr=ztEd&**O;qJWNX2}EhpwOASi4+uA@fF_mOCk8j^ zFfMJJYq2zF(izt)!#sB7>hZ$$>M$;KIODEC>WKkjV4Z=wF&(V`=kUvHeB`gEkVOm- z1OJQx-WdC14>o0M>yPbWt(DO3p`l=0i3$klYnK2p@EqxAr}0bFAM`ap#soA(j>-l>_r}ncO-AH6h6+LZ2&YM$-$ z!K}R>7TGjRzOTRA{p4G<`u0vkG(=0>`%kLy^B|vRUNE~w>q@C4tn?tfjwi*~**jNh z9>nQnrVHX|0x5SlaT=+@Q}Z;+bggd!qA8kVr@dGVkB1%E8=NdVa&g-4cI4n_xLh{H z{=wnd<@h;yN!6R7k^|FPb}d%$4hm=eSAUi!Dt**e6(b}Dhyh}N7$62#4A_%Iv{o+C zhyh}N7|;yh{ve?tx)w`=dh37=ug@5-BcgzgZwW+c(Y06_ga`;XrGTcC+b0G$<=~e# z&$U<@H06xznPD6|bM<)PdUo(jozA#xka}W(7+7bZW=sdq{|oqKHa_y#OUNPyh=G5` z0I!eykp~}T&(Cq7Hek#nK>-f__~NNEZQ3 L2zA83FEH>0R6a*$ literal 0 HcmV?d00001 diff --git a/saved_models/lang14/pb/2/.DS_Store b/saved_models/lang14/pb/2/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..19ad97c2515317ea36d3e4a7a8782588c85f5515 GIT binary patch literal 6148 zcmeHK%}T>S5T0$TO(;SS3Oz1(E!b);ikDF93mDOZN=<0dV9b`LwTDv3SzpK}@p+ut z-4seIcoMNQF!SxsPnP)-cC!E=I+K1Apb7vEDq+DzvqmURx*`SZAr$%<0fZ340Mc-} zlFg34$N;@N7hWNN5kzPg?=KyOu?jKz5f0-h9XA@EqEIYtZI_+0Q*rM7CpGoQ!FZB( zg3&F_u9S+yd>n+=(O}xE?w+e;97M@rqzj_{08{R6qNJ~;9W_b%sjl@+z;PY7SFO!v z?c;V`Hd`n2x}2RJHtMo<)Sk~>XK(-D?6UV9za;9-&?)fC)3ReRhj$#T>Fm`XC9z5# zF`8HA6^zUPGr$aN3Ilfkb1IwCF!#v}Fatkkfc6K8O6Xb44C<`|4gEe+yhKQXW4cQa zN{gPw%pgWkgegTdrNUk@gegbAw0WMz%%CX;p;yM|*p-F7p$NS?`lU_>;TdGh3@`)B z3>3|@O6UK@_xJzhBAzh=%)q~5K$N best_acc: - best_acc = precision - model.save_weights(dir_current + '/best/' + 'model') - model.save_weights(dir_current + '/last/' + 'model') - template = ("\rEpoch {}, Loss: {:.4f}, Val Loss: {:.4f}, " - "Val Acc: {:.4f}, test ACC: {:.4f},F1: {:.4f}, precision: {:.4f}, recall: {:.4f}, Time Cost: {:.2f} sec") - text = template.format(epoch, train_loss / config.running_config['train_steps'], - 
dev_loss/ config.running_config['dev_steps'], dev_accuracy *100, - test_acc*100, test_f1*100, precision*100, recall*100, time.time() - start) - print(colored(text, 'cyan')) - log_file.write(text) + # Reset metrics + train_loss.reset_states() + val_loss.reset_states() + train_accuracy.reset_states() + val_accuracy.reset_states() + + # Training + for step, inputs in enumerate(train_dist_dataset): + loss = distributed_train_step(inputs) + if step % 10 == 0: + template = 'Epoch {}, Step {}, Loss: {:.4f}, Accuracy: {:.4f}' + print(colored(template.format( + epoch, step + 1, + train_loss.result(), + train_accuracy.result() + ), 'green')) + + # Validation + all_predictions = [] + all_labels = [] + for inputs in eval_dist_dataset: + predictions, labels = distributed_test_step(inputs) + all_predictions.extend(tf.argmax(predictions, axis=-1).numpy()) + all_labels.extend(labels.numpy()) + + # Calculate metrics + val_f1 = f1_score(y_true=all_labels, y_pred=all_predictions, average='macro') + val_precision = precision_score(y_true=all_labels, y_pred=all_predictions, average='macro', zero_division=1) + val_recall = recall_score(y_true=all_labels, y_pred=all_predictions, average='macro') + + # Save best model + if val_precision > best_accuracy: + best_accuracy = val_precision + model.save_weights(os.path.join(dir_current, 'best', 'model')) + model.save_weights(os.path.join(dir_current, 'last', 'model')) + + # Log results + template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.4f}, Val Loss: {:.4f}, Val Accuracy: {:.4f}, F1: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, Time: {:.2f}s' + print(template.format( + epoch, + train_loss.result(), + train_accuracy.result(), + val_loss.result(), + val_accuracy.result(), + val_f1, + val_precision, + val_recall, + time.time() - start_time + )) + log_file.write(template.format( + epoch, + train_loss.result(), + train_accuracy.result(), + val_loss.result(), + val_accuracy.result(), + val_f1, + val_precision, + val_recall, + time.time() - start_time + ) + '\n') log_file.flush() - plot_train_loss.append(train_loss / config.running_config['train_steps']) - plot_dev_loss.append(dev_loss / config.running_config['dev_steps']) - plot_acc.append(test_acc) - plot_precision.append(precision) - ckpt_manager.save() - - plt.plot(plot_train_loss, '-r', label='train_loss') - plt.title('Train Loss') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/loss.png') - #plot dev - plt.clf() - plt.plot(plot_dev_loss, '-g', label='dev_loss') - plt.title('dev Loss') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/dev_loss.png') - - # plot acc curve - plt.clf() - plt.plot(plot_acc, 'b-', label='acc') - plt.title('Accuracy') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/acc.png') - # plot f1 curve - plt.clf() - plt.plot(plot_precision, 'y-', label='f1-score') - plt.title('F1') - plt.xlabel('Epochs') - plt.savefig(dir_current + '/f1-score.png') - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Spoken_language_identification Model training") parser.add_argument("--config_file", type=str, default='./configs/config.yml', help="Config File Path") args = parser.parse_args() kwargs = vars(args) - with mirrored_strategy.scope(): + with strategy.scope(): train(**kwargs) \ No newline at end of file diff --git a/train_multilingual.py b/train_multilingual.py new file mode 100644 index 0000000..b9bf9c2 --- /dev/null +++ b/train_multilingual.py @@ -0,0 +1,147 @@ +import os +import json +import tensorflow as tf +from tqdm import tqdm +from datetime import datetime +from 
multilingual_dataset import MultilingualDataset +from featurizers.speech_featurizers import SpeechFeaturizer +from configs.config import Config +from vocab.vocab import Vocab + +def setup_mixed_precision(): + """Setup mixed precision for better performance on ARM32""" + policy = tf.keras.mixed_precision.Policy('mixed_float16') + tf.keras.mixed_precision.set_global_policy(policy) + +def create_model(config, num_languages): + """Create the model with support for multiple languages""" + inputs = tf.keras.Input(shape=(None, config.speech_config['num_feature_bins'])) + x = inputs + + # CNN layers + for filters, kernel in zip(config.model_config['filters'], config.model_config['kernel_size']): + x = tf.keras.layers.Conv2D( + filters=filters, + kernel_size=kernel, + padding='same', + activation='relu' + )(tf.expand_dims(x, axis=-1)) + x = tf.keras.layers.BatchNormalization()(x) + x = tf.squeeze(x, axis=-1) + + # BiLSTM layers + x = tf.keras.layers.Bidirectional( + tf.keras.layers.LSTM( + config.model_config['rnn_cell'], + return_sequences=True + ) + )(x) + + # Global pooling + x = tf.keras.layers.GlobalAveragePooling1D()(x) + + # Output layer + outputs = tf.keras.layers.Dense(num_languages, activation='softmax')(x) + + model = tf.keras.Model(inputs=inputs, outputs=outputs) + return model + +def main(): + # Load configuration + config = Config("configs/config.yml") + + # Setup mixed precision if enabled + if config.optimizer_config.get('use_mixed_precision', False): + setup_mixed_precision() + + # Load language configuration + with open(config.dataset_config['languages_file'], 'r') as f: + languages_config = json.load(f) + languages = languages_config['supported_languages'] + + # Initialize components + vocab = Vocab(config.dataset_config['vocabulary']) + speech_featurizer = SpeechFeaturizer(config.speech_config) + + # Create datasets + train_dataset = MultilingualDataset( + config=config, + languages=languages, + vocab=vocab, + speech_featurizer=speech_featurizer, + data_type='train', + max_samples_per_language=config.dataset_config['max_samples_per_language'] + ) + + val_dataset = MultilingualDataset( + config=config, + languages=languages, + vocab=vocab, + speech_featurizer=speech_featurizer, + data_type='validation', + max_samples_per_language=config.dataset_config['max_samples_per_language'] // 10 + ) + + # Create model + model = create_model(config, len(languages)) + + # Compile model + optimizer = tf.keras.optimizers.Adam( + learning_rate=config.optimizer_config['max_lr'], + beta_1=config.optimizer_config['beta1'], + beta_2=config.optimizer_config['beta2'], + epsilon=config.optimizer_config['epsilon'] + ) + + model.compile( + optimizer=optimizer, + loss='sparse_categorical_crossentropy', + metrics=['accuracy'] + ) + + # Setup callbacks + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + save_dir = os.path.join('saved_weights', 'multilingual', timestamp) + os.makedirs(save_dir, exist_ok=True) + + callbacks = [ + tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(save_dir, 'epoch_{epoch:02d}'), + save_weights_only=True, + save_freq='epoch' + ), + tf.keras.callbacks.TensorBoard( + log_dir=os.path.join('logs', timestamp), + update_freq='batch' + ), + tf.keras.callbacks.EarlyStopping( + monitor='val_loss', + patience=5, + restore_best_weights=True + ) + ] + + # Train model + train_generator = train_dataset.get_batch_generator(config.running_config['batch_size']) + val_generator = val_dataset.get_batch_generator(config.running_config['batch_size']) + + steps_per_epoch = 
config.running_config['train_steps'] + validation_steps = config.running_config['dev_steps'] + + model.fit( + train_generator, + steps_per_epoch=steps_per_epoch, + validation_data=val_generator, + validation_steps=validation_steps, + epochs=config.running_config['num_epochs'], + callbacks=callbacks + ) + + # Save final model + model.save_weights(os.path.join(save_dir, 'final')) + + # Save language mapping + train_dataset.save_language_mapping(os.path.join(save_dir, 'language_mapping.json')) + +if __name__ == '__main__': + main() \ No newline at end of file From c1072f97c37f98fc34d2ed8cb9b4abb803391eb5 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:05:31 +0530 Subject: [PATCH 02/14] changes for multilingual features --- multilingual_dataset.py | 8 ++++---- train_multilingual.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 2a620e1..c270247 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -6,7 +6,7 @@ from typing import List, Dict, Optional from datasets import load_dataset, Audio from transformers import Wav2Vec2Processor -from featurizers.speech_featurizers import SpeechFeaturizer +from featurizers.speech_featurizers import NumpySpeechFeaturizer from configs.config import Config from vocab.vocab import Vocab @@ -16,7 +16,7 @@ def __init__( config: Config, languages: List[str], vocab: Vocab, - speech_featurizer: SpeechFeaturizer, + speech_featurizer: NumpySpeechFeaturizer, data_type: str = "train", max_samples_per_language: Optional[int] = None ): @@ -52,7 +52,7 @@ def load_datasets(self): print(f"Error loading dataset for {lang}: {str(e)}") def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray: - """Process audio data to extract features""" + """Process audio data to extract features using NumPy-based extraction""" if sampling_rate != self.config.speech_config['sample_rate']: # Resample if necessary import librosa @@ -62,7 +62,7 @@ def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarra target_sr=self.config.speech_config['sample_rate'] ) - # Extract features using the speech featurizer + # Extract features using NumPy-based feature extraction features = self.speech_featurizer.extract(audio_data) return features diff --git a/train_multilingual.py b/train_multilingual.py index b9bf9c2..5f55e73 100644 --- a/train_multilingual.py +++ b/train_multilingual.py @@ -4,7 +4,7 @@ from tqdm import tqdm from datetime import datetime from multilingual_dataset import MultilingualDataset -from featurizers.speech_featurizers import SpeechFeaturizer +from featurizers.speech_featurizers import NumpySpeechFeaturizer from configs.config import Config from vocab.vocab import Vocab @@ -61,7 +61,7 @@ def main(): # Initialize components vocab = Vocab(config.dataset_config['vocabulary']) - speech_featurizer = SpeechFeaturizer(config.speech_config) + speech_featurizer = NumpySpeechFeaturizer(config.speech_config) # Create datasets train_dataset = MultilingualDataset( From 69900f1bb8bad9d3b1c2f4b6bf2ace3b73f04a8a Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:17:12 +0530 Subject: [PATCH 03/14] changes for multilingual dataset --- configs/config.yml | 5 +++- multilingual_dataset.py | 58 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 7 deletions(-) diff --git a/configs/config.yml 
b/configs/config.yml index 37c69da..517fbf5 100644 --- a/configs/config.yml +++ b/configs/config.yml @@ -24,6 +24,7 @@ dataset_config: vocabulary: vocab/vocab.txt data_path: ./data/wavs/ corpus_name: ./data/multilingual/ + fleurs_path: ./data/fleurs/ file_nums: 1 max_audio_length: 2000 shuffle_size: 1200 @@ -31,10 +32,12 @@ dataset_config: suffix: .txt load_type: multilingual train: train - dev: dev + dev: validation test: test languages_file: configs/languages.json max_samples_per_language: 10000 + audio_format: wav + metadata_format: json optimizer_config: init_steps: 0 diff --git a/multilingual_dataset.py b/multilingual_dataset.py index c270247..5ca36be 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -3,9 +3,9 @@ import numpy as np import pandas as pd from tqdm import tqdm -from typing import List, Dict, Optional -from datasets import load_dataset, Audio -from transformers import Wav2Vec2Processor +from typing import List, Dict, Optional, Any +from datasets import Dataset, Audio +import soundfile as sf from featurizers.speech_featurizers import NumpySpeechFeaturizer from configs.config import Config from vocab.vocab import Vocab @@ -35,12 +35,58 @@ def __init__( # Load FLEURS dataset for multiple languages self.load_datasets() + def load_local_dataset(self, lang: str) -> Dataset: + """Load dataset from local files for a specific language""" + lang_path = os.path.join(self.config.dataset_config['fleurs_path'], lang, self.data_type) + if not os.path.exists(lang_path): + raise ValueError(f"Dataset path not found: {lang_path}") + + # Load metadata + metadata_file = os.path.join(lang_path, f"metadata.{self.config.dataset_config['metadata_format']}") + if not os.path.exists(metadata_file): + raise ValueError(f"Metadata file not found: {metadata_file}") + + with open(metadata_file, 'r', encoding='utf-8') as f: + metadata = json.load(f) + + # Create dataset dictionary + dataset_dict = { + 'audio': [], + 'transcription': [], + 'language': [], + 'id': [] + } + + # Process each sample + for item in metadata['data']: + audio_path = os.path.join(lang_path, 'audio', f"{item['id']}.{self.config.dataset_config['audio_format']}") + if not os.path.exists(audio_path): + print(f"Warning: Audio file not found: {audio_path}") + continue + + try: + # Load audio file + audio_data, sample_rate = sf.read(audio_path) + dataset_dict['audio'].append({ + 'array': audio_data, + 'sampling_rate': sample_rate, + 'path': audio_path + }) + dataset_dict['transcription'].append(item.get('transcription', '')) + dataset_dict['language'].append(lang) + dataset_dict['id'].append(item['id']) + except Exception as e: + print(f"Error loading audio file {audio_path}: {str(e)}") + continue + + return Dataset.from_dict(dataset_dict) + def load_datasets(self): - """Load datasets for all specified languages""" + """Load datasets for all specified languages from local files""" for lang in tqdm(self.languages, desc="Loading languages"): try: - # Load FLEURS dataset for the language - dataset = load_dataset("google/fleurs", lang, split=self.data_type) + # Load local dataset for the language + dataset = self.load_local_dataset(lang) # Apply sampling if specified if self.max_samples_per_language: From dc863b8557590e014266df83d22c8436b5e3b0a6 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:50:42 +0530 Subject: [PATCH 04/14] download and set fleurs dataset --- configs/languages.json | 19 +++-- download_fleurs.py | 184 
++++++++++++++++++++++++++++++++++++++++ multilingual_dataset.py | 131 +++++++++++++++------------- requirements.txt | 3 +- 4 files changed, 269 insertions(+), 68 deletions(-) create mode 100644 download_fleurs.py diff --git a/configs/languages.json b/configs/languages.json index 78986ba..eecf280 100644 --- a/configs/languages.json +++ b/configs/languages.json @@ -2,14 +2,7 @@ "supported_languages": [ "en", "es", "fr", "de", "it", "pt", "nl", "pl", "ru", "uk", "ar", "hi", "bn", "ta", "te", "mr", "ur", "fa", "tr", "he", - "th", "vi", "id", "ms", "fil", "ja", "ko", "zh", "yue", "sw", - "af", "am", "az", "be", "bg", "ca", "cs", "da", "el", "et", - "eu", "fi", "ga", "gl", "gu", "ha", "hr", "hu", "hy", "ig", - "is", "ka", "kk", "km", "kn", "ky", "lb", "lg", "lt", "lv", - "mk", "ml", "mn", "mt", "my", "ne", "no", "pa", "ps", "ro", - "si", "sk", "sl", "sn", "so", "sq", "sr", "sv", "tg", "tk", - "uz", "wo", "xh", "yi", "yo", "zu", "as", "bho", "doi", "mai", - "or", "raj", "sa", "sd", "cy", "fo", "gd", "kw", "fy", "rm" + "th", "vi", "id", "ms", "fil", "ja", "ko", "zh" ], "language_names": { "en": "English", @@ -31,6 +24,14 @@ "ur": "Urdu", "fa": "Persian", "tr": "Turkish", - "he": "Hebrew" + "he": "Hebrew", + "th": "Thai", + "vi": "Vietnamese", + "id": "Indonesian", + "ms": "Malay", + "fil": "Filipino", + "ja": "Japanese", + "ko": "Korean", + "zh": "Chinese" } } \ No newline at end of file diff --git a/download_fleurs.py b/download_fleurs.py new file mode 100644 index 0000000..7ff0280 --- /dev/null +++ b/download_fleurs.py @@ -0,0 +1,184 @@ +import os +import json +import argparse +import shutil +import time +from tqdm import tqdm +from datasets import load_dataset, get_dataset_config_names +import soundfile as sf +import numpy as np +from pathlib import Path + +# All FLEURS languages +ALL_LANGUAGES = [ + 'af', 'am', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'br', 'bs', 'ca', 'cs', 'cy', 'da', + 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'ga', 'gl', 'gu', 'ha', 'he', + 'hi', 'hr', 'hu', 'hy', 'id', 'ig', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km', 'kn', + 'ko', 'ky', 'lb', 'lg', 'ln', 'lo', 'lt', 'lv', 'mg', 'mk', 'ml', 'mn', 'mr', 'ms', + 'my', 'ne', 'nl', 'no', 'ny', 'or', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'rw', 'sd', + 'si', 'sk', 'sl', 'sn', 'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', + 'tk', 'tr', 'uk', 'ur', 'uz', 'vi', 'wo', 'xh', 'yi', 'yo', 'zh', 'zu' +] + +def ensure_dir(path): + """Create directory if it doesn't exist""" + Path(path).mkdir(parents=True, exist_ok=True) + +def save_audio(audio_data, sample_rate, output_path): + """Save audio data to WAV file""" + sf.write(output_path, audio_data, sample_rate) + +def download_language(lang, output_dir, splits=None, retry_count=3, retry_delay=5): + """Download and organize dataset for a specific language with retries""" + if splits is None: + splits = ['train', 'validation', 'test'] + + lang_dir = os.path.join(output_dir, lang) + print(f"\nProcessing language: {lang}") + + for split in splits: + print(f"\nDownloading {split} split...") + split_dir = os.path.join(lang_dir, split) + audio_dir = os.path.join(split_dir, 'audio') + + # Skip if already downloaded + metadata_path = os.path.join(split_dir, 'metadata.json') + if os.path.exists(metadata_path): + print(f"Skipping {lang} {split} - already downloaded") + continue + + ensure_dir(audio_dir) + + # Load dataset with retries + dataset = None + for attempt in range(retry_count): + try: + dataset = load_dataset("google/fleurs", lang, split=split) + break + except Exception 
as e: + if attempt < retry_count - 1: + print(f"Attempt {attempt + 1} failed for {lang} {split}: {str(e)}") + print(f"Retrying in {retry_delay} seconds...") + time.sleep(retry_delay) + else: + print(f"Error downloading {lang} {split} after {retry_count} attempts: {str(e)}") + return False + + if dataset is None: + continue + + # Prepare metadata + metadata = { + 'data': [], + 'lang': lang, + 'split': split + } + + # Process each example + for idx, item in enumerate(tqdm(dataset, desc=f"Processing {split}")): + try: + # Extract audio + audio_data = item['audio']['array'] + sample_rate = item['audio']['sampling_rate'] + + # Generate ID + item_id = f"{lang}_{split}_{idx:06d}" + + # Save audio file + audio_path = os.path.join(audio_dir, f"{item_id}.wav") + save_audio(audio_data, sample_rate, audio_path) + + # Add to metadata + metadata['data'].append({ + 'id': item_id, + 'transcription': item.get('transcription', ''), + 'raw_transcription': item.get('raw_transcription', ''), + 'language': item.get('language', lang), + 'gender': item.get('gender', ''), + 'lang_id': item.get('lang_id', -1) + }) + + except Exception as e: + print(f"Error processing item {idx} in {lang} {split}: {str(e)}") + continue + + # Save metadata + with open(metadata_path, 'w', encoding='utf-8') as f: + json.dump(metadata, f, ensure_ascii=False, indent=2) + + print(f"Saved {len(metadata['data'])} examples for {lang} {split}") + + return True + +def download_languages_in_batches(languages, output_dir, batch_size=5, splits=None): + """Download languages in batches to manage memory usage""" + total_languages = len(languages) + successful = [] + failed = [] + + for i in range(0, total_languages, batch_size): + batch = languages[i:i + batch_size] + print(f"\nProcessing batch {i//batch_size + 1} of {(total_languages + batch_size - 1)//batch_size}") + print(f"Languages in this batch: {', '.join(batch)}") + + for lang in batch: + try: + if download_language(lang, output_dir, splits): + successful.append(lang) + else: + failed.append(lang) + except Exception as e: + print(f"Failed to download {lang}: {str(e)}") + failed.append(lang) + + # Clear some memory + if i + batch_size < total_languages: + print("\nClearing memory before next batch...") + time.sleep(5) # Give some time for memory cleanup + + return successful, failed + +def main(): + parser = argparse.ArgumentParser(description='Download and organize FLEURS dataset') + parser.add_argument('--output_dir', type=str, default='./data/fleurs', + help='Output directory for the dataset') + parser.add_argument('--languages', type=str, nargs='+', + help='List of language codes to download (default: all languages)') + parser.add_argument('--splits', type=str, nargs='+', + default=['train', 'validation', 'test'], + help='Dataset splits to download') + parser.add_argument('--batch_size', type=int, default=5, + help='Number of languages to download in parallel') + args = parser.parse_args() + + # Use all languages if none specified + languages = args.languages if args.languages else ALL_LANGUAGES + + # Create output directory + ensure_dir(args.output_dir) + + # Download languages in batches + print(f"Starting download of {len(languages)} languages in batches of {args.batch_size}") + successful, failed = download_languages_in_batches( + languages, args.output_dir, args.batch_size, args.splits + ) + + # Print summary + print("\n=== Download Summary ===") + print(f"Successfully downloaded: {len(successful)} languages") + print(f"Failed to download: {len(failed)} languages") + + if failed: + 
print("\nFailed languages:") + print(", ".join(failed)) + + # Save failed languages to file for retry + failed_file = os.path.join(args.output_dir, "failed_languages.txt") + with open(failed_file, 'w') as f: + f.write("\n".join(failed)) + print(f"\nFailed languages list saved to: {failed_file}") + print("You can retry failed languages using:") + print(f"python download_fleurs.py --languages {' '.join(failed)}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 5ca36be..0c5c2bf 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -3,9 +3,9 @@ import numpy as np import pandas as pd from tqdm import tqdm -from typing import List, Dict, Optional, Any -from datasets import Dataset, Audio -import soundfile as sf +from typing import List, Dict, Optional +from datasets import load_dataset, Dataset, load_from_disk +from huggingface_hub import HfFileSystem from featurizers.speech_featurizers import NumpySpeechFeaturizer from configs.config import Config from vocab.vocab import Vocab @@ -35,59 +35,59 @@ def __init__( # Load FLEURS dataset for multiple languages self.load_datasets() - def load_local_dataset(self, lang: str) -> Dataset: - """Load dataset from local files for a specific language""" - lang_path = os.path.join(self.config.dataset_config['fleurs_path'], lang, self.data_type) - if not os.path.exists(lang_path): - raise ValueError(f"Dataset path not found: {lang_path}") - - # Load metadata - metadata_file = os.path.join(lang_path, f"metadata.{self.config.dataset_config['metadata_format']}") - if not os.path.exists(metadata_file): - raise ValueError(f"Metadata file not found: {metadata_file}") - - with open(metadata_file, 'r', encoding='utf-8') as f: - metadata = json.load(f) - - # Create dataset dictionary - dataset_dict = { - 'audio': [], - 'transcription': [], - 'language': [], - 'id': [] - } - - # Process each sample - for item in metadata['data']: - audio_path = os.path.join(lang_path, 'audio', f"{item['id']}.{self.config.dataset_config['audio_format']}") - if not os.path.exists(audio_path): - print(f"Warning: Audio file not found: {audio_path}") - continue - - try: - # Load audio file - audio_data, sample_rate = sf.read(audio_path) - dataset_dict['audio'].append({ - 'array': audio_data, - 'sampling_rate': sample_rate, - 'path': audio_path - }) - dataset_dict['transcription'].append(item.get('transcription', '')) - dataset_dict['language'].append(lang) - dataset_dict['id'].append(item['id']) - except Exception as e: - print(f"Error loading audio file {audio_path}: {str(e)}") - continue + def find_dataset_path(self, lang: str) -> Optional[str]: + """Find the dataset path in HuggingFace cache""" + cache_dir = os.path.expanduser("~/.cache/huggingface/datasets") + dataset_dir = os.path.join(cache_dir, "google-fleurs", lang) + + if not os.path.exists(dataset_dir): + print(f"Dataset directory not found for language {lang}") + return None + + # Look for the downloaded version + versions = [d for d in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, d))] + if not versions: + print(f"No dataset versions found for language {lang}") + return None + + # Use the latest version + latest_version = sorted(versions)[-1] + dataset_path = os.path.join(dataset_dir, latest_version) + + return dataset_path if os.path.exists(dataset_path) else None - return Dataset.from_dict(dataset_dict) + def load_local_dataset(self, lang: str) -> Optional[Dataset]: + """Load dataset from HuggingFace 
cache""" + try: + # Find the dataset path + dataset_path = self.find_dataset_path(lang) + if dataset_path is None: + return None + + # Load the dataset + dataset = load_from_disk(dataset_path) + + # Get the appropriate split + if self.data_type in dataset: + return dataset[self.data_type] + else: + print(f"Split {self.data_type} not found for language {lang}") + return None + + except Exception as e: + print(f"Error loading dataset for language {lang}: {str(e)}") + return None def load_datasets(self): - """Load datasets for all specified languages from local files""" + """Load datasets for all specified languages""" for lang in tqdm(self.languages, desc="Loading languages"): try: # Load local dataset for the language dataset = self.load_local_dataset(lang) + if dataset is None: + continue + # Apply sampling if specified if self.max_samples_per_language: dataset = dataset.select(range(min(len(dataset), self.max_samples_per_language))) @@ -116,23 +116,38 @@ def get_batch_generator(self, batch_size: int): """Generate batches of data""" while True: for lang in self.languages: + if lang not in self.dataset_cache: + continue + dataset = self.dataset_cache[lang] + indices = list(range(len(dataset))) + + if self.config.dataset_config.get('shuffle', True): + np.random.shuffle(indices) - for i in range(0, len(dataset), batch_size): - batch_data = dataset[i:i + batch_size] + for i in range(0, len(indices), batch_size): + batch_indices = indices[i:i + batch_size] + batch_data = dataset.select(batch_indices) features_list = [] labels = [] for item in batch_data: - # Process audio - audio_data = item['audio']['array'] - sampling_rate = item['audio']['sampling_rate'] - features = self.prepare_audio(audio_data, sampling_rate) - features_list.append(features) - - # Get language label - labels.append(self.language_to_id[lang]) + try: + # Process audio + audio_data = item['audio']['array'] + sampling_rate = item['audio']['sampling_rate'] + features = self.prepare_audio(audio_data, sampling_rate) + features_list.append(features) + + # Get language label + labels.append(self.language_to_id[lang]) + except Exception as e: + print(f"Error processing item in {lang}: {str(e)}") + continue + + if not features_list: + continue # Pad features to same length max_len = max(feat.shape[0] for feat in features_list) diff --git a/requirements.txt b/requirements.txt index f2f6e5a..22f8fad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ scikit-learn==1.2.2 PyYAML==6.0 tqdm>=4.65.0 pandas>=2.0.0 -transformers>=4.30.0 datasets>=2.12.0 +huggingface-hub>=0.16.4 +transformers>=4.30.0 From 3161e38a3b36118897c96e120c518b6295773926 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:28:56 +0530 Subject: [PATCH 05/14] fix issue in train model --- train_multilingual.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/train_multilingual.py b/train_multilingual.py index 5f55e73..87cfa73 100644 --- a/train_multilingual.py +++ b/train_multilingual.py @@ -13,6 +13,22 @@ def setup_mixed_precision(): policy = tf.keras.mixed_precision.Policy('mixed_float16') tf.keras.mixed_precision.set_global_policy(policy) +class ExpandDimsLayer(tf.keras.layers.Layer): + def __init__(self, axis=-1, **kwargs): + super().__init__(**kwargs) + self.axis = axis + + def call(self, inputs): + return tf.expand_dims(inputs, axis=self.axis) + +class SqueezeLayer(tf.keras.layers.Layer): + def __init__(self, axis=-1, **kwargs): + 
super().__init__(**kwargs) + self.axis = axis + + def call(self, inputs): + return tf.squeeze(inputs, axis=self.axis) + def create_model(config, num_languages): """Create the model with support for multiple languages""" inputs = tf.keras.Input(shape=(None, config.speech_config['num_feature_bins'])) @@ -20,14 +36,15 @@ def create_model(config, num_languages): # CNN layers for filters, kernel in zip(config.model_config['filters'], config.model_config['kernel_size']): + x = ExpandDimsLayer(axis=-1)(x) x = tf.keras.layers.Conv2D( filters=filters, kernel_size=kernel, padding='same', activation='relu' - )(tf.expand_dims(x, axis=-1)) + )(x) x = tf.keras.layers.BatchNormalization()(x) - x = tf.squeeze(x, axis=-1) + x = SqueezeLayer(axis=-1)(x) # BiLSTM layers x = tf.keras.layers.Bidirectional( From 6510d06b2d3ec28e3011fcb4d46cd72b9be2bda6 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:38:33 +0530 Subject: [PATCH 06/14] fix issue in train model --- multilingual_dataset.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 0c5c2bf..e00139d 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -36,7 +36,20 @@ def __init__( self.load_datasets() def find_dataset_path(self, lang: str) -> Optional[str]: - """Find the dataset path in HuggingFace cache""" + """Find the dataset path in local hub/datasets structure or HuggingFace cache""" + # First try local data directory + local_dir = os.path.join("data", "fleurs", "hub", "datasets--google--fluers") + if os.path.exists(local_dir): + # Look for blob directories + blob_dirs = [d for d in os.listdir(local_dir) if d.startswith("blobs")] + if blob_dirs: + # Use the first blob directory found + dataset_path = os.path.join(local_dir, blob_dirs[0], lang) + if os.path.exists(dataset_path): + print(f"Found local dataset for {lang} at {dataset_path}") + return dataset_path + + # Fallback to HuggingFace cache cache_dir = os.path.expanduser("~/.cache/huggingface/datasets") dataset_dir = os.path.join(cache_dir, "google-fleurs", lang) From 18ca6a7f6d0984f83326d5419767d0e582cf80b5 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 09:45:36 +0530 Subject: [PATCH 07/14] fix issue --- configs/languages.json | 43 ++++++---------------- multilingual_dataset.py | 81 +++++++++++++++++++++++++---------------- 2 files changed, 62 insertions(+), 62 deletions(-) diff --git a/configs/languages.json b/configs/languages.json index eecf280..5e737e1 100644 --- a/configs/languages.json +++ b/configs/languages.json @@ -1,37 +1,18 @@ { "supported_languages": [ - "en", "es", "fr", "de", "it", "pt", "nl", "pl", "ru", "uk", - "ar", "hi", "bn", "ta", "te", "mr", "ur", "fa", "tr", "he", - "th", "vi", "id", "ms", "fil", "ja", "ko", "zh" + "be_by", + "bg_bg", + "bs_ba", + "ca_cs", + "cs_cz", + "cy_gb" ], "language_names": { - "en": "English", - "es": "Spanish", - "fr": "French", - "de": "German", - "it": "Italian", - "pt": "Portuguese", - "nl": "Dutch", - "pl": "Polish", - "ru": "Russian", - "uk": "Ukrainian", - "ar": "Arabic", - "hi": "Hindi", - "bn": "Bengali", - "ta": "Tamil", - "te": "Telugu", - "mr": "Marathi", - "ur": "Urdu", - "fa": "Persian", - "tr": "Turkish", - "he": "Hebrew", - "th": "Thai", - "vi": "Vietnamese", - "id": "Indonesian", - "ms": "Malay", - "fil": "Filipino", - "ja": "Japanese", - "ko": "Korean", - "zh": "Chinese" + "be_by": "Belarusian", + 
"bg_bg": "Bulgarian", + "bs_ba": "Bosnian", + "ca_cs": "Catalan", + "cs_cz": "Czech", + "cy_gb": "Welsh" } } \ No newline at end of file diff --git a/multilingual_dataset.py b/multilingual_dataset.py index e00139d..6cf4660 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -4,7 +4,7 @@ import pandas as pd from tqdm import tqdm from typing import List, Dict, Optional -from datasets import load_dataset, Dataset, load_from_disk +from datasets import load_dataset, Dataset, load_from_disk, concatenate_datasets from huggingface_hub import HfFileSystem from featurizers.speech_featurizers import NumpySpeechFeaturizer from configs.config import Config @@ -36,56 +36,75 @@ def __init__( self.load_datasets() def find_dataset_path(self, lang: str) -> Optional[str]: - """Find the dataset path in local hub/datasets structure or HuggingFace cache""" - # First try local data directory - local_dir = os.path.join("data", "fleurs", "hub", "datasets--google--fluers") - if os.path.exists(local_dir): - # Look for blob directories - blob_dirs = [d for d in os.listdir(local_dir) if d.startswith("blobs")] - if blob_dirs: - # Use the first blob directory found - dataset_path = os.path.join(local_dir, blob_dirs[0], lang) - if os.path.exists(dataset_path): - print(f"Found local dataset for {lang} at {dataset_path}") - return dataset_path - - # Fallback to HuggingFace cache - cache_dir = os.path.expanduser("~/.cache/huggingface/datasets") - dataset_dir = os.path.join(cache_dir, "google-fleurs", lang) + """Find the dataset path in the local datasets directory structure""" + # Dataset root directory + dataset_root = os.path.join("fleurs", "datasets", "google__fluers", lang) - if not os.path.exists(dataset_dir): + if not os.path.exists(dataset_root): print(f"Dataset directory not found for language {lang}") return None - # Look for the downloaded version - versions = [d for d in os.listdir(dataset_dir) if os.path.isdir(os.path.join(dataset_dir, d))] + # Look for version directories (e.g., 2.0.0) + versions = [d for d in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, d))] if not versions: print(f"No dataset versions found for language {lang}") return None # Use the latest version latest_version = sorted(versions)[-1] - dataset_path = os.path.join(dataset_dir, latest_version) + dataset_path = os.path.join(dataset_root, latest_version) + + # Verify that necessary files exist + required_files = [ + "dataset_info.json", + "fleurs-train-00000-of-00003.arrow", + "fleurs-validation.arrow", + "fleurs-test.arrow" + ] - return dataset_path if os.path.exists(dataset_path) else None + for file in required_files: + if not os.path.exists(os.path.join(dataset_path, file)): + print(f"Missing required file {file} for language {lang}") + return None + + print(f"Found dataset for {lang} at {dataset_path}") + return dataset_path def load_local_dataset(self, lang: str) -> Optional[Dataset]: - """Load dataset from HuggingFace cache""" + """Load dataset from local directory""" try: # Find the dataset path dataset_path = self.find_dataset_path(lang) if dataset_path is None: return None - # Load the dataset - dataset = load_from_disk(dataset_path) - - # Get the appropriate split - if self.data_type in dataset: - return dataset[self.data_type] + # Load the dataset based on the data type + if self.data_type == "train": + # For training data, we need to load and concatenate multiple shards + shards = [] + for i in range(3): # We know there are 3 shards for training + shard_path = os.path.join(dataset_path, 
f"fleurs-train-{i:05d}-of-00003.arrow") + if os.path.exists(shard_path): + shard_dataset = Dataset.from_file(shard_path) + shards.append(shard_dataset) + else: + print(f"Warning: Missing shard {i} for language {lang}") + + if not shards: + print(f"No training shards found for language {lang}") + return None + + # Concatenate all shards + dataset = concatenate_datasets(shards) else: - print(f"Split {self.data_type} not found for language {lang}") - return None + # For validation and test, we have single files + file_path = os.path.join(dataset_path, f"fleurs-{self.data_type}.arrow") + if not os.path.exists(file_path): + print(f"Dataset file not found: {file_path}") + return None + dataset = Dataset.from_file(file_path) + + return dataset except Exception as e: print(f"Error loading dataset for language {lang}: {str(e)}") From bcf84f5253eb916519378eda8e451ae10fb5718c Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 09:58:11 +0530 Subject: [PATCH 08/14] hash dir --- multilingual_dataset.py | 45 +++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 6cf4660..e7fbce9 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -52,23 +52,38 @@ def find_dataset_path(self, lang: str) -> Optional[str]: # Use the latest version latest_version = sorted(versions)[-1] - dataset_path = os.path.join(dataset_root, latest_version) + version_path = os.path.join(dataset_root, latest_version) - # Verify that necessary files exist - required_files = [ - "dataset_info.json", - "fleurs-train-00000-of-00003.arrow", - "fleurs-validation.arrow", - "fleurs-test.arrow" - ] - - for file in required_files: - if not os.path.exists(os.path.join(dataset_path, file)): - print(f"Missing required file {file} for language {lang}") + # Look for hash directory + try: + hash_dirs = [d for d in os.listdir(version_path) if os.path.isdir(os.path.join(version_path, d))] + if not hash_dirs: + print(f"No hash directory found for language {lang}") return None - - print(f"Found dataset for {lang} at {dataset_path}") - return dataset_path + + # Use the first hash directory found + hash_dir = hash_dirs[0] + dataset_path = os.path.join(version_path, hash_dir) + + # Verify that necessary files exist + required_files = [ + "dataset_info.json", + "fleurs-train-00000-of-00003.arrow", + "fleurs-validation.arrow", + "fleurs-test.arrow" + ] + + for file in required_files: + if not os.path.exists(os.path.join(dataset_path, file)): + print(f"Missing required file {file} for language {lang}") + return None + + print(f"Found dataset for {lang} at {dataset_path}") + return dataset_path + + except Exception as e: + print(f"Error accessing hash directory for language {lang}: {str(e)}") + return None def load_local_dataset(self, lang: str) -> Optional[Dataset]: """Load dataset from local directory""" From 6f669b253a49a9f63e87cd00befbe357da2bb1bb Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 10:07:26 +0530 Subject: [PATCH 09/14] training_shards --- multilingual_dataset.py | 39 +++++++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index e7fbce9..8664c5a 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -68,11 +68,21 @@ def find_dataset_path(self, lang: str) -> Optional[str]: # Verify that 
necessary files exist required_files = [ "dataset_info.json", - "fleurs-train-00000-of-00003.arrow", "fleurs-validation.arrow", "fleurs-test.arrow" ] + # Check for at least one training shard + train_shard_found = False + for file in os.listdir(dataset_path): + if file.startswith("fleurs-train-") and file.endswith(".arrow"): + train_shard_found = True + break + + if not train_shard_found: + print(f"No training shards found for language {lang}") + return None + for file in required_files: if not os.path.exists(os.path.join(dataset_path, file)): print(f"Missing required file {file} for language {lang}") @@ -85,6 +95,21 @@ def find_dataset_path(self, lang: str) -> Optional[str]: print(f"Error accessing hash directory for language {lang}: {str(e)}") return None + def get_num_shards(self, dataset_path: str) -> int: + """Determine the number of training shards in the dataset""" + shard_files = [f for f in os.listdir(dataset_path) if f.startswith("fleurs-train-") and f.endswith(".arrow")] + if not shard_files: + return 0 + + # Extract the total number of shards from the filename pattern + # Example: "fleurs-train-00000-of-00004.arrow" -> 4 + sample_file = shard_files[0] + try: + total_shards = int(sample_file.split("-of-")[1].split(".")[0]) + return total_shards + except: + return len(shard_files) + def load_local_dataset(self, lang: str) -> Optional[Dataset]: """Load dataset from local directory""" try: @@ -95,10 +120,16 @@ def load_local_dataset(self, lang: str) -> Optional[Dataset]: # Load the dataset based on the data type if self.data_type == "train": - # For training data, we need to load and concatenate multiple shards + # Determine number of shards + num_shards = self.get_num_shards(dataset_path) + if num_shards == 0: + print(f"No training shards found for language {lang}") + return None + + # Load all available shards shards = [] - for i in range(3): # We know there are 3 shards for training - shard_path = os.path.join(dataset_path, f"fleurs-train-{i:05d}-of-00003.arrow") + for i in range(num_shards): + shard_path = os.path.join(dataset_path, f"fleurs-train-{i:05d}-of-{num_shards:05d}.arrow") if os.path.exists(shard_path): shard_dataset = Dataset.from_file(shard_path) shards.append(shard_dataset) From ae3300a18f762f47daefadf9881213680196cfd5 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 10:35:28 +0530 Subject: [PATCH 10/14] changes in multilingual dataset loading ... 
--- multilingual_dataset.py | 46 ++++++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 8664c5a..757515a 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -131,8 +131,12 @@ def load_local_dataset(self, lang: str) -> Optional[Dataset]: for i in range(num_shards): shard_path = os.path.join(dataset_path, f"fleurs-train-{i:05d}-of-{num_shards:05d}.arrow") if os.path.exists(shard_path): - shard_dataset = Dataset.from_file(shard_path) - shards.append(shard_dataset) + try: + # Load dataset with memory mapping to avoid PyArrow conversion issues + shard_dataset = Dataset.from_file(shard_path, keep_in_memory=False) + shards.append(shard_dataset) + except Exception as e: + print(f"Error loading shard {i} for language {lang}: {str(e)}") else: print(f"Warning: Missing shard {i} for language {lang}") @@ -141,16 +145,25 @@ def load_local_dataset(self, lang: str) -> Optional[Dataset]: return None # Concatenate all shards - dataset = concatenate_datasets(shards) + try: + dataset = concatenate_datasets(shards) + return dataset + except Exception as e: + print(f"Error concatenating shards for language {lang}: {str(e)}") + return None else: # For validation and test, we have single files file_path = os.path.join(dataset_path, f"fleurs-{self.data_type}.arrow") if not os.path.exists(file_path): print(f"Dataset file not found: {file_path}") return None - dataset = Dataset.from_file(file_path) - - return dataset + try: + # Load dataset with memory mapping + dataset = Dataset.from_file(file_path, keep_in_memory=False) + return dataset + except Exception as e: + print(f"Error loading {self.data_type} dataset for language {lang}: {str(e)}") + return None except Exception as e: print(f"Error loading dataset for language {lang}: {str(e)}") @@ -205,23 +218,32 @@ def get_batch_generator(self, batch_size: int): for i in range(0, len(indices), batch_size): batch_indices = indices[i:i + batch_size] - batch_data = dataset.select(batch_indices) features_list = [] labels = [] - for item in batch_data: + # Get batch items one by one to avoid PyArrow conversion issues + for idx in batch_indices: try: - # Process audio - audio_data = item['audio']['array'] + # Get item and immediately convert to dict + item = dataset[idx] + if not isinstance(item, dict): + item = dict(item) + + # Process audio data + audio_array = item['audio']['array'] + if isinstance(audio_array, (list, tuple)): + audio_array = np.array(audio_array) sampling_rate = item['audio']['sampling_rate'] - features = self.prepare_audio(audio_data, sampling_rate) + + # Extract features + features = self.prepare_audio(audio_array, sampling_rate) features_list.append(features) # Get language label labels.append(self.language_to_id[lang]) except Exception as e: - print(f"Error processing item in {lang}: {str(e)}") + print(f"Error processing item {idx} in {lang}: {str(e)}") continue if not features_list: From a20c56aab8a22b75a8ffd6a8e3e70c4c17a11737 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 10:40:21 +0530 Subject: [PATCH 11/14] changes in multilingual datasets script --- multilingual_dataset.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 757515a..fd73319 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -132,8 +132,8 @@ def load_local_dataset(self, lang: 
str) -> Optional[Dataset]: shard_path = os.path.join(dataset_path, f"fleurs-train-{i:05d}-of-{num_shards:05d}.arrow") if os.path.exists(shard_path): try: - # Load dataset with memory mapping to avoid PyArrow conversion issues - shard_dataset = Dataset.from_file(shard_path, keep_in_memory=False) + # Load dataset shard + shard_dataset = Dataset.from_file(shard_path) shards.append(shard_dataset) except Exception as e: print(f"Error loading shard {i} for language {lang}: {str(e)}") @@ -158,8 +158,7 @@ def load_local_dataset(self, lang: str) -> Optional[Dataset]: print(f"Dataset file not found: {file_path}") return None try: - # Load dataset with memory mapping - dataset = Dataset.from_file(file_path, keep_in_memory=False) + dataset = Dataset.from_file(file_path) return dataset except Exception as e: print(f"Error loading {self.data_type} dataset for language {lang}: {str(e)}") @@ -222,19 +221,30 @@ def get_batch_generator(self, batch_size: int): features_list = [] labels = [] - # Get batch items one by one to avoid PyArrow conversion issues + # Process items one at a time for idx in batch_indices: try: - # Get item and immediately convert to dict + # Get item and ensure it's a dictionary item = dataset[idx] if not isinstance(item, dict): item = dict(item) # Process audio data - audio_array = item['audio']['array'] + audio_data = item['audio'] + if isinstance(audio_data, dict): + audio_array = audio_data.get('array') + sampling_rate = audio_data.get('sampling_rate') + else: + print(f"Unexpected audio data format for item {idx} in {lang}") + continue + + if audio_array is None or sampling_rate is None: + print(f"Missing audio data or sampling rate for item {idx} in {lang}") + continue + + # Convert audio array if needed if isinstance(audio_array, (list, tuple)): audio_array = np.array(audio_array) - sampling_rate = item['audio']['sampling_rate'] # Extract features features = self.prepare_audio(audio_array, sampling_rate) From 741b840b132af4e5bab4d3e58b3d68c490386f42 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 11:03:36 +0530 Subject: [PATCH 12/14] changes in stft --- featurizers/speech_featurizers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/featurizers/speech_featurizers.py b/featurizers/speech_featurizers.py index fb8f83b..70f277a 100644 --- a/featurizers/speech_featurizers.py +++ b/featurizers/speech_featurizers.py @@ -232,10 +232,18 @@ def shape(self) -> list: return [None, self.num_feature_bins, channel_dim] def stft(self, signal): + if len(signal) < self.nfft: + print(f"[Skip] Signal too short for STFT: len({len(signal)}) < nfft = {self.nfft}") + return np.zeros((self.nfft//2 + 1 ,1)) + max_len = 10 * self.sample_rate + if len(signal) > max_len: + print(f"[Truncate] Signal too long for STFT: len({len(signal)}) > max_len = {max_len}") + signal = signal[:max_len] return np.square( np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step, win_length=self.frame_length, center=True, window="hann"))) + def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0): return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db) From 84a4eb4cefbddfe1c8385c3a6f1e0809813f9f68 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 11:20:52 +0530 Subject: [PATCH 13/14] changes in log mel --- featurizers/speech_featurizers.py | 64 +++++++++++++++++++------------ multilingual_dataset.py | 46 ++++++++++++++++------ 2 files changed, 74 
insertions(+), 36 deletions(-) diff --git a/featurizers/speech_featurizers.py b/featurizers/speech_featurizers.py index 70f277a..d46cfe6 100644 --- a/featurizers/speech_featurizers.py +++ b/featurizers/speech_featurizers.py @@ -317,15 +317,46 @@ def compute_mfcc(self, signal: np.ndarray) -> np.ndarray: return mfcc.T def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray: - S = self.stft(signal) - - mel = librosa.filters.mel(self.sample_rate, self.nfft, - n_mels=self.num_feature_bins, - fmin=0.0, fmax=int(self.sample_rate / 2)) - - mel_spectrogram = np.dot(S.T, mel.T) - - return self.power_to_db(mel_spectrogram) + """Compute log mel spectrogram with proper error handling for long signals""" + try: + # Handle long signals + max_len = 320000 # Maximum length for STFT + if len(signal) > max_len: + print(f"[Truncate] Signal too long: len({len(signal)}) > max_len = {max_len}") + # Take the center portion + start = (len(signal) - max_len) // 2 + signal = signal[start:start + max_len] + + # Compute STFT + S = self.stft(signal) + + # Create mel filterbank if not already created + if self.mel_filter is None: + self.mel_filter = librosa.filters.mel( + sr=self.sample_rate, + n_fft=self.nfft, + n_mels=self.num_feature_bins, + fmin=0.0, + fmax=int(self.sample_rate / 2) + ) + + # Apply mel filterbank + mel_spectrogram = np.dot(S.T, self.mel_filter.T) + + # Convert to log scale + log_mel_spec = self.power_to_db(mel_spectrogram) + + # Handle any NaN or Inf values + if np.isnan(log_mel_spec).any() or np.isinf(log_mel_spec).any(): + print("Warning: NaN or Inf values in log mel spectrogram, replacing with zeros") + log_mel_spec = np.nan_to_num(log_mel_spec, 0) + + return log_mel_spec + + except Exception as e: + print(f"Error computing log mel spectrogram: {str(e)}") + # Return empty spectrogram with correct shape + return np.zeros((1, self.num_feature_bins)) def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: S = self.stft(signal) @@ -418,21 +449,6 @@ def tf_extract(self, signal: tf.Tensor) -> tf.Tensor: return features - def compute_log_mel_spectrogram(self, signal): - spectrogram = self.stft(signal) - if self.mel_filter is None: - linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix( - num_mel_bins=self.num_feature_bins, - num_spectrogram_bins=spectrogram.shape[-1], - sample_rate=self.sample_rate, - lower_edge_hertz=0.0, upper_edge_hertz=(self.sample_rate / 2) - ) - else: - linear_to_weight_matrix = self.mel_filter - - mel_spectrogram = tf.tensordot(spectrogram, linear_to_weight_matrix, 1) - return self.power_to_db(mel_spectrogram) - def compute_spectrogram(self, signal): S = self.stft(signal) spectrogram = self.power_to_db(S) diff --git a/multilingual_dataset.py b/multilingual_dataset.py index fd73319..69a7737 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -189,18 +189,40 @@ def load_datasets(self): def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray: """Process audio data to extract features using NumPy-based extraction""" - if sampling_rate != self.config.speech_config['sample_rate']: - # Resample if necessary - import librosa - audio_data = librosa.resample( - audio_data, - orig_sr=sampling_rate, - target_sr=self.config.speech_config['sample_rate'] - ) - - # Extract features using NumPy-based feature extraction - features = self.speech_featurizer.extract(audio_data) - return features + try: + if sampling_rate != self.config.speech_config['sample_rate']: + # Resample if necessary + import librosa 
+ audio_data = librosa.resample( + audio_data, + orig_sr=sampling_rate, + target_sr=self.config.speech_config['sample_rate'] + ) + + # Trim silence + audio_data, _ = librosa.effects.trim(audio_data, top_db=30) + + # Handle long audio by splitting into chunks if necessary + max_samples = 320000 # Maximum samples for STFT + if len(audio_data) > max_samples: + # Take the center portion of the audio + start = (len(audio_data) - max_samples) // 2 + audio_data = audio_data[start:start + max_samples] + + # Extract features using the speech featurizer + features = self.speech_featurizer.extract(audio_data) + + # Ensure the features are in the correct range + if np.isnan(features).any() or np.isinf(features).any(): + print("Warning: NaN or Inf values in features, replacing with zeros") + features = np.nan_to_num(features, 0) + + return features + + except Exception as e: + print(f"Error in prepare_audio: {str(e)}") + # Return empty features with correct shape as fallback + return np.zeros((1, self.config.speech_config['num_feature_bins'])) def get_batch_generator(self, batch_size: int): """Generate batches of data""" From 22fbfe0f3b65d757cedd3018f6b93a5628e5e1a2 Mon Sep 17 00:00:00 2001 From: PRANJAL BHARTI <88613437+Prashu-10@users.noreply.github.com> Date: Tue, 3 Jun 2025 11:34:43 +0530 Subject: [PATCH 14/14] changes in multilingual and featurizer --- featurizers/speech_featurizers.py | 10 ++++++---- multilingual_dataset.py | 13 ++++--------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/featurizers/speech_featurizers.py b/featurizers/speech_featurizers.py index d46cfe6..aca7324 100644 --- a/featurizers/speech_featurizers.py +++ b/featurizers/speech_featurizers.py @@ -235,10 +235,12 @@ def stft(self, signal): if len(signal) < self.nfft: print(f"[Skip] Signal too short for STFT: len({len(signal)}) < nfft = {self.nfft}") return np.zeros((self.nfft//2 + 1 ,1)) - max_len = 10 * self.sample_rate + max_len = 320000 # Increased from 160000 to match other parts of the code if len(signal) > max_len: print(f"[Truncate] Signal too long for STFT: len({len(signal)}) > max_len = {max_len}") - signal = signal[:max_len] + # Take the center portion of the signal + start = (len(signal) - max_len) // 2 + signal = signal[start:start + max_len] return np.square( np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step, win_length=self.frame_length, center=True, window="hann"))) @@ -319,7 +321,7 @@ def compute_mfcc(self, signal: np.ndarray) -> np.ndarray: def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray: """Compute log mel spectrogram with proper error handling for long signals""" try: - # Handle long signals + # Handle long signals - using the same max_len as stft max_len = 320000 # Maximum length for STFT if len(signal) > max_len: print(f"[Truncate] Signal too long: len({len(signal)}) > max_len = {max_len}") @@ -328,7 +330,7 @@ def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray: signal = signal[start:start + max_len] # Compute STFT - S = self.stft(signal) + S = self.stft(signal) # stft will handle any remaining length issues # Create mel filterbank if not already created if self.mel_filter is None: diff --git a/multilingual_dataset.py b/multilingual_dataset.py index 69a7737..d9eec3d 100644 --- a/multilingual_dataset.py +++ b/multilingual_dataset.py @@ -190,11 +190,13 @@ def load_datasets(self): def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarray: """Process audio data to extract features using NumPy-based 
extraction""" try: + # Import librosa here to ensure it's available + import librosa + if sampling_rate != self.config.speech_config['sample_rate']: # Resample if necessary - import librosa audio_data = librosa.resample( - audio_data, + y=audio_data, orig_sr=sampling_rate, target_sr=self.config.speech_config['sample_rate'] ) @@ -202,13 +204,6 @@ def prepare_audio(self, audio_data: np.ndarray, sampling_rate: int) -> np.ndarra # Trim silence audio_data, _ = librosa.effects.trim(audio_data, top_db=30) - # Handle long audio by splitting into chunks if necessary - max_samples = 320000 # Maximum samples for STFT - if len(audio_data) > max_samples: - # Take the center portion of the audio - start = (len(audio_data) - max_samples) // 2 - audio_data = audio_data[start:start + max_samples] - # Extract features using the speech featurizer features = self.speech_featurizer.extract(audio_data)