diff --git a/aiu_fms_testing_utils/utils/__init__.py b/aiu_fms_testing_utils/utils/__init__.py
index 92997067..3b3473fc 100644
--- a/aiu_fms_testing_utils/utils/__init__.py
+++ b/aiu_fms_testing_utils/utils/__init__.py
@@ -9,6 +9,7 @@
# Third Party
from aiu_fms_testing_utils.utils.aiu_setup import dprint
from fms.utils.tokenizers import BaseTokenizer
+from fms.utils.generation import pad_input_ids
import torch
import torch.nn as nn
@@ -166,3 +167,43 @@ def sample_squad_v2_qa_requests(
prompt_length_max,
seed,
)
+
+def prepare_inputs(batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"):
+ """
+ Prepare input IDs and padding kwargs for a batch of questions.
+
+ Args:
+ batch_size (int): The number of questions in the batch.
+ seq_length (int): The maximum length of the input sequence.
+ tokenizer (Tokenizer): A tokenizer object to tokenize the questions.
+ ds_path (str): The path to the dataset file.
+ seed (int, optional): The random seed for reproducibility. Defaults to 0.
+ ds_type (str, optional): The type of dataset to use. Can be "sharegpt" or any other supported dataset type. Defaults to "sharegpt".
+
+ Returns:
+ tuple: A tuple containing the input IDs and padding kwargs.
+ """
+ if not "sharegpt" in ds_type:
+ prompts_and_sizes = sample_squad_v2_qa_requests(
+ ds_path,
+ batch_size,
+ tokenizer,
+ int(seq_length / 2),
+ seq_length,
+ seed,
+ )
+ else:
+ prompts_and_sizes = sample_sharegpt_requests(
+ ds_path,
+ batch_size,
+ tokenizer,
+ int(seq_length / 2),
+ seq_length,
+ seed,
+ )
+ prompt_list = []
+ for prompt, _ in prompts_and_sizes:
+ prompt_list.append(ids_for_prompt(prompt, tokenizer))
+
+ input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
+ return input_ids, padding_kwargs
diff --git a/aiu_fms_testing_utils/utils/metrics_utils.py b/aiu_fms_testing_utils/utils/metrics_utils.py
new file mode 100644
index 00000000..9f011786
--- /dev/null
+++ b/aiu_fms_testing_utils/utils/metrics_utils.py
@@ -0,0 +1,72 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+
+def abs_diff_linalg_norm(res_vector):
+ """
+ Calculates the Euclidean norm (also known as the L2 norm) of a given array res_vector. This is equivalent to finding the square
+ root of the sum of the squares of all the elements in the array. It's a fundamental operation in linear algebra and is often used
+ to measure the "length" or "magnitude" of a vector. More at https://numpy.org/devdocs/reference/generated/numpy.linalg.norm.html
+ Args:
+ res_vector (list): The list of abs diff
+
+ Returns:
+ float: "magnitude" of the diff vector.
+ """
+ return np.linalg.norm(res_vector)
+
+def list_mean(val_list):
+ """
+ Calculates the mean for all the values in a given list.
+ Args:
+ val_list (list): The list of values
+
+ Returns:
+ float: mean value calculated.
+ """
+ return np.mean(val_list)
+
+def tensor_abs_diff(tensor1, tensor2):
+ """
+ Calculate the absolute difference between two tensors.
+
+ Args:
+ tensor1 (torch.Tensor): The first input tensor.
+ tensor2 (torch.Tensor): The second input tensor.
+
+ Returns:
+ torch.Tensor: The absolute difference tensor.
+
+ Example:
+ >>> tensor1 = torch.tensor([1, 2, 3])
+ >>> tensor2 = torch.tensor([4, 5, 6])
+ >>> abs_diff(tensor1, tensor2)
+ torch.tensor([3, 3, 3])
+ """
+ abs_diff = torch.abs(tensor1 - tensor2)
+ return abs_diff
+
+def tensor_cos_sim(tensor1, tensor2):
+ """
+ Computes the cosine similarity between two tensors.
+
+ Args:
+ tensor1 (torch.Tensor): The first input tensor.
+ tensor2 (torch.Tensor): The second input tensor.
+
+ Returns:
+ torch.Tensor: The cosine similarity between the two input tensors.
+
+ Example:
+ >>> import torch
+ >>> tensor1 = torch.randn(3, 5)
+ >>> tensor2 = torch.randn(3, 5)
+ >>> sim = cos_sim(tensor1, tensor2)
+ >>> print(sim)
+ """
+ cos = nn.CosineSimilarity(dim=-1)
+ tensor1[tensor1 == 0.0] = 1e-6
+ tensor2[tensor2 == 0.0] = 1e-6
+ cos_sim = cos(tensor1, tensor2)
+ return cos_sim
\ No newline at end of file
diff --git a/scripts/generate_layers_metrics.py b/scripts/generate_layers_metrics.py
new file mode 100644
index 00000000..48d41f62
--- /dev/null
+++ b/scripts/generate_layers_metrics.py
@@ -0,0 +1,441 @@
+import os
+import time
+import logging
+import argparse
+
+import itertools
+import torch
+import torch.nn as nn
+
+from fms.utils import tokenizers
+from fms.models import get_model
+from fms.utils.generation import generate
+
+from aiu_fms_testing_utils.testing.validation import get_default_validation_prefix
+
+from aiu_fms_testing_utils.utils import prepare_inputs
+from aiu_fms_testing_utils.utils.metrics_utils import tensor_abs_diff, tensor_cos_sim
+
+
+logger = logging.getLogger(__name__)
+LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(message)s")
+
+parser = argparse.ArgumentParser(
+ description="Script to generate the model's metrics by layer"
+)
+parser.add_argument(
+ "--architecture",
+ type=str,
+ help="The model architecture Eg.: hf_pretrained",
+)
+parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="The model variants (configuration) to benchmark. E.g. ibm-granite/granite-3.2-8b-instruct",
+)
+parser.add_argument(
+ "--model_path",
+ type=str,
+ help="Paths to the directory containing model's weights (.pth files sharded by tensor parallel rank, not HF weights)",
+)
+parser.add_argument(
+ "--mode",
+ choices=["generate", "model-forward"],
+ default="generate",
+ required=True,
+ help="Sets the output generation mode."
+)
+parser.add_argument(
+ "--batch_sizes",
+ type=str,
+ default="1",
+ required=True,
+ help="Batch sizes separated by comma. Eg.: 1,2"
+)
+parser.add_argument(
+ "--seq_lengths",
+ type=str,
+ default="64",
+ required=True,
+ help="Sequence lengths separated by comma. Eg.: 64,2048"
+)
+parser.add_argument(
+ "--max_new_tokens",
+ type=str,
+ default="128",
+ required=True,
+ help="Max number of generated tokens separated by comma. Eg.: 64,128"
+)
+parser.add_argument(
+ "--output_path",
+ type=str,
+ default="/tmp/output",
+ help="Path to save output files"
+)
+parser.add_argument(
+ "--sharegpt_path",
+ type=str,
+ default=os.path.expanduser("~/share_gpt.json"),
+ help="Path to sharegpt data json"
+)
+
+args = parser.parse_args()
+mode = args.mode
+output_path = args.output_path
+sharegpt_path = args.sharegpt_path
+
+common_model_paths = args.model_path if args.model_path else args.variant
+if isinstance(common_model_paths, str):
+ common_model_paths = [str(bs) for bs in common_model_paths.split(",")]
+
+# pass custom common batch sizes as a comma separated str of ints
+common_batch_sizes = args.batch_sizes
+if isinstance(common_batch_sizes, str):
+ common_batch_sizes = [int(bs) for bs in common_batch_sizes.split(",")]
+
+# pass custom common seq lengths as a comma separated str of ints
+common_seq_lengths = args.seq_lengths
+if isinstance(common_seq_lengths, str):
+ common_seq_lengths = [int(sl) for sl in common_seq_lengths.split(",")]
+
+# pass custom common max new tokens as a comma separated str of ints
+common_max_new_tokens = args.max_new_tokens
+if isinstance(common_max_new_tokens, str):
+ common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")]
+
+common_shapes = list(
+ itertools.product(
+ common_model_paths,
+ common_batch_sizes,
+ common_seq_lengths,
+ common_max_new_tokens,
+ )
+)
+
+generate_iters = 0
+
+def __infer_layer(model, max_len, device, max_new_tokens, batch_size, tokenizer):
+ """
+ Infer a model with registered layer hooks using generated inputs.
+
+ Args:
+ model (nn.Module): The model to infer.
+ max_len (int): The maximum length of the input sequence.
+ device (str): The device to use for inference.
+ max_new_tokens (int): The maximum number of new tokens to generate.
+ batch_size (int): The batch size for inference.
+ tokenizer (Tokenizer): The tokenizer to use for encoding inputs.
+
+ Returns:
+ torch.Tensor: The inferred model's layers output.
+ """
+
+ do_sample = False
+ use_cache = True
+
+ prompts = prepare_inputs(batch_size, max_len, tokenizer, sharegpt_path)
+ ids, pad_input_ids = prompts
+
+ if "cuda" in device:
+ ids = ids.to("cuda")
+
+ if hasattr(model.config, "ntk_scaling") and model.config.ntk_scaling:
+ max_seq_len = max(max_len, model.config.max_expected_seq_len)
+ else:
+ # without ntk scaling, extending the seq length too far gives bogus results.
+ max_seq_len = model.config.max_expected_seq_len
+
+ if "generate" in mode:
+ with torch.no_grad():
+ result = generate(
+ model,
+ ids,
+ max_new_tokens=max_new_tokens,
+ use_cache=use_cache,
+ do_sample=do_sample,
+ max_seq_len=max_seq_len,
+ timing="e2e",
+ eos_token_id=None,
+ contiguous_cache=True,
+ extra_kwargs={},
+ )
+ result, timings = result
+ logger.info(f"Generation completed: Result len is {len(result)}")
+ if len(result.shape) == 1:
+ result = result.unsqueeze(0)
+ else:
+ result = model.forward(
+ ids,
+ use_cache=use_cache
+ )
+ logger.info(f"Model forward completed: Result len is {len(result)}")
+
+def __register_call_layers(model, batch_size, device, seq_length, max_new_tokens, tokenizer):
+ """
+ This function registers hooks on the model to track the forward pass of each layer.
+ It returns a list of tuples containing the name and output of each layer in the model.
+
+ Args:
+ model (nn.Module): The PyTorch model to be analyzed.
+ batch_size (int): The batch size used for inference.
+ device (torch.device): The device on which the model is running.
+ seq_length (int): The maximum sequence length of the input data.
+ max_new_tokens (int): The maximum number of new tokens to be generated during inference.
+ tokenizer (Tokenizer): The tokenizer used for tokenization.
+
+ Returns:
+ list: A list of tuples containing the name and output of each layer in the model.
+ """
+ layer_stack = {}
+ pt_compile_model_time = time.time()
+
+ module_depth = {}
+ module_name = {}
+
+ def register_depths(module, current_depth=0, name='model'):
+ module_depth[module] = current_depth
+ module_name[module] = name
+ parent=name
+ # if we are dealing with array of layers
+ array_layers = all(key.isdigit() for key in module._modules.keys())
+ for name, child in module._modules.items():
+ if array_layers:
+ register_depths(child, current_depth + 1, parent+'['+name+']')
+ else:
+ register_depths(child, current_depth + 1, parent+'.'+name)
+
+ register_depths(model)
+
+ def wrap_forward(layer):
+ original_forward = layer.forward
+
+ def safe_forward(*args, **kwargs):
+ try:
+ return original_forward(*args, **kwargs)
+ except (RuntimeError,TypeError) as e:
+ logger.error(f"Error in {layer.__class__.__name__}: {e}")
+ return torch.zeros_like(args[0]) if args else None
+ layer.forward = safe_forward
+
+
+ hooks = []
+ def pre_hook_fn(module, input):
+ depth = module_depth.get(module, 0)
+ layer_name = module_name.get(module, 0)
+ prefix = 'β ' * depth
+ if len(input) == 0: return
+ input_shape_str = f"[{', '.join(map(str, input[0].shape))}]"
+ input_type = str(input[0].dtype)
+ if module.parameters() == None: return
+ param_size = sum(p.numel() for p in module.parameters() if p.requires_grad)
+ param_size_str = f"{param_size:,}" if param_size > 0 else "--"
+ logger.info(f"{prefix}ββ{layer_name}() -> {module.__class__.__name__} : | Input(arg): {input_shape_str} | {input_type} | Params: {param_size_str}")
+ wrap_forward(module)
+ # save input for later use with outputs
+ module._debug_input = input
+
+ def post_hook_fn(module, input, output):
+ layer_name = module_name.get(module, 0)
+ # Save inputs and outputs
+ tmp = {}
+ if hasattr(module, '_debug_input'):
+ global generate_iters
+ generate_iters += 1
+ layer_name = f"{layer_name}.iter-{generate_iters}" if layer_name in layer_stack.keys() else layer_name
+ tmp[layer_name] = output
+ layer_stack.update(tmp)
+ # Clean up
+ delattr(module, '_debug_input')
+
+ for name, layer in model.named_modules():
+ hooks.append(layer.register_forward_pre_hook(pre_hook_fn))
+ hooks.append(layer.register_forward_hook(post_hook_fn))
+
+
+ __infer_layer(model= model, max_len=seq_length,
+ device=device, max_new_tokens=max_new_tokens,
+ batch_size=batch_size, tokenizer=tokenizer)
+
+ for hook in hooks:
+ hook.remove()
+
+ pt_compile_model_time = time.time() - pt_compile_model_time
+ logger.info(f"PT compile complete, took {pt_compile_model_time:.3f}s")
+
+ return layer_stack
+
+def get_metric_values(metric_list):
+ if isinstance(metric_list, list):
+ # shape of the first tensor to be printed in file
+ metric_shape = metric_list[0].shape
+ metric_list_res = torch.stack(metric_list).flatten().tolist()
+ else:
+ metric_shape = metric_list.shape
+ metric_list_res = metric_list.flatten().tolist()
+
+ return metric_list_res, metric_shape
+
+def write_csv(values, path, metric, gpu_layer_shape, cpu_layer_shape, output_shape):
+ """
+ Write values to a CSV file at the given path.
+
+ Args:
+ values (list or float): A list of values to be written to the CSV file.
+ If `values` is a single float, it will be written as a scalar value in the first column of the CSV file.
+ path (str): The path to the CSV file to write to.
+ metric (str): The name of the metric being evaluated.
+ gpu_layer_shape (tuple): The shape of the GPU layer used for training.
+ cpu_layer_shape (tuple): The shape of the CPU layer used for training.
+ output_shape (tuple): The shape of the output generated by the model.
+
+ Returns:
+ None
+ """
+ with open(path, 'w') as f:
+ f.write(f'{metric}\n')
+ f.write(f'GPU shape {gpu_layer_shape} CPU shape {cpu_layer_shape}\n')
+ f.write(f'Metric shape {output_shape}\n')
+ if not isinstance(values, float):
+ for t in values:
+ f.write(f"{t}\n")
+ else:
+ f.write(f"{values}\n")
+ f.close()
+
+def generate_layers_metrics(model_path, batch_size, seq_length, max_new_tokens):
+ """
+ Generate metrics for layers in a given model.
+
+ Args:
+ model_path (str): The path to the Hugging Face model.
+ batch_size (int): The batch size used for inference.
+ seq_length (int): The sequence length used for inference.
+ max_new_tokens (int): The maximum number of new tokens allowed for generation.
+
+ Returns:
+ None
+ """
+
+ torch.manual_seed(42)
+ os.environ["COMPILATION_MODE"] = "offline_decoder"
+
+ if "HF_HOME" not in os.environ:
+ os.environ["HF_HOME"] = "/tmp/models/hf_cache"
+
+ model_path_kwargs = {"variant": model_path} if args.variant else {"model_path": model_path}
+ micro_model_kwargs = {"architecture": args.architecture}
+
+ get_model_kwargs = {
+ **model_path_kwargs,
+ **micro_model_kwargs,
+ }
+
+ tokenizer = tokenizers.get_tokenizer(model_path)
+
+ # prepare the cpu model
+ validation_model = get_model(
+ device_type="cpu",
+ data_type=torch.float32,
+ fused_weights=False,
+ **get_model_kwargs,
+ )
+
+ # prepare the cuda model
+ validation_model_cuda = get_model(
+ device_type="cuda",
+ data_type=torch.float16,
+ fused_weights=False,
+ **get_model_kwargs,
+ )
+
+ layer_stack_cpu = __register_call_layers(model=validation_model,
+ batch_size=batch_size,
+ device="cpu",
+ seq_length=seq_length, max_new_tokens=max_new_tokens,
+ tokenizer=tokenizer)
+
+ global generate_iters
+ generate_iters = 0
+ logger.info(f"Finished registering CPU layers")
+
+ layer_stack_cuda = __register_call_layers(model=validation_model_cuda,
+ batch_size=batch_size,
+ device="cuda",
+ seq_length=seq_length, max_new_tokens=max_new_tokens,
+ tokenizer=tokenizer)
+
+ assert len(layer_stack_cuda.keys()) == len(layer_stack_cpu.keys())
+
+ for layer_key, output_val in layer_stack_cuda.items():
+
+ tensor_cpu_out = None
+ tensor_cuda_out = None
+
+ if layer_key in layer_stack_cpu.keys():
+ cpu_output = layer_stack_cpu[layer_key]
+ cuda_output = output_val
+ logger.info(f"Comparing CPU and GPU Layer {layer_key} output")
+
+ if type(cpu_output) is tuple and type(cuda_output) is tuple:
+ cos_sim = []
+ abs_diff = []
+ if len(cpu_output) < 2 and len(cpu_output[-1]) == 1:
+ # Projection layers (often called "query," "key," and "value" projections) are used to transform the input embeddings
+ # into separate query, key, and value vectors. They have tuple outputs, with more than 2 tensors - this path compares this type of output;
+ # In case of head layers, the last item of the tuple is a list of tensors with the same lenght as the
+ # number of layers in the model. The else path compares this other case.
+ tensor_cuda_out = cuda_output[-1]
+ tensor_cpu_out = cpu_output[-1]
+ for i in range(len(cpu_output)):
+ logger.debug(f"inputs: {cuda_output[i].shape} {cpu_output[i].to('cuda').shape}")
+ cos_sim.append(tensor_cos_sim(cuda_output[i], cpu_output[i].to('cuda')))
+ logger.debug(f"cos_sim output:{tensor_cos_sim(cuda_output[i], cpu_output[i].to('cuda')).shape}")
+ abs_diff.append(tensor_abs_diff(cuda_output[i], cpu_output[i].to('cuda')))
+ else:
+ head_tensor_cpu = cpu_output[-1]
+ head_tensor_gpu = cuda_output[-1]
+ for i in range(len(head_tensor_gpu)):
+ if isinstance(head_tensor_gpu[i], tuple):
+ for j in range(len(head_tensor_gpu[i])):
+ tensor_cuda_out = head_tensor_gpu[i][j]
+ tensor_cpu_out = head_tensor_cpu[i][j]
+ logger.debug(f"inputs: {head_tensor_gpu[i][j].shape} {head_tensor_cpu[i][j].to('cuda').shape}")
+ cos_sim.append(tensor_cos_sim(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]))
+ logger.debug(f"cos_sim output:{tensor_cos_sim(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]).shape}")
+ abs_diff.append(tensor_abs_diff(head_tensor_cpu[i][j].to('cuda'), head_tensor_gpu[i][j]))
+ else:
+ tensor_cuda_out = head_tensor_gpu[i]
+ tensor_cpu_out = head_tensor_cpu[i]
+ logger.debug(f"inputs: {head_tensor_gpu[i].shape} {head_tensor_cpu[i].to('cuda').shape}")
+ cos_sim.append(tensor_cos_sim(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]))
+ logger.debug(f"cos_sim output:{tensor_cos_sim(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]).shape}")
+ abs_diff.append(tensor_abs_diff(head_tensor_cpu[i].to('cuda'), head_tensor_gpu[i]))
+ else:
+ tensor_cpu_out = cpu_output.to('cuda')
+ tensor_cuda_out = cuda_output
+ abs_diff = tensor_abs_diff(tensor_cpu_out, cuda_output)
+ cos_sim = tensor_cos_sim(tensor_cpu_out, cuda_output)
+
+ prefix = get_default_validation_prefix(model_path, max_new_token, batch_size, seq_length, 'float16')
+ layer_name = str(layer_key).replace('[','').replace(']', '')
+
+ abs_diff_path = os.path.join(output_path, f"{prefix}--{layer_name}.abs_diff.csv")
+ cos_sim_path = os.path.join(output_path, f"{prefix}--{layer_name}.cos_sim.csv")
+
+ cos_sim_res, cos_shape = get_metric_values(cos_sim)
+ abs_diff_res, abs_diff_shape = get_metric_values(abs_diff)
+
+ if not os.path.exists(abs_diff_path):
+ logger.debug("saving abs_diff files")
+ write_csv(abs_diff_res, abs_diff_path, "abs_diff", tensor_cuda_out.shape, tensor_cpu_out.shape, abs_diff_shape)
+ if not os.path.exists(cos_sim_path):
+ logger.debug("saving cos_sim files")
+ write_csv(cos_sim_res, cos_sim_path, "cos_sim", tensor_cuda_out.shape, tensor_cpu_out.shape, cos_shape)
+
+ logger.info(f"Completed {model_path} layers' metrics generation with {mode} mode")
+
+for model_id, batch_size, sequence_length, max_new_token in common_shapes:
+ logger.info(f"testing model_id-{model_id}, max_new_tokens-{max_new_token}, batch_size-{batch_size}, seq_length-{sequence_length}")
+ generate_layers_metrics(model_path=model_id, batch_size=batch_size, seq_length=sequence_length, max_new_tokens=max_new_token)
diff --git a/tests/LAYERS.md b/tests/LAYERS.md
new file mode 100644
index 00000000..ae1ccde0
--- /dev/null
+++ b/tests/LAYERS.md
@@ -0,0 +1,132 @@
+# Layer Metrics Generation
+
+Generate metrics by layers to be used in tests and model enablement debugging.
+
+1. [Generate metrics by layer in GPU](./LAYERS.md#1-generate-metrics-by-layer)
+2. [Get Thresholds](./LAYERS.md#2-get-thresholds)
+3. [Apply metrics where needed](./LAYERS.md#3-apply-the-thresholds-where-needed)
+
+The steps as part of the diagram below:
+
+To see the full integration with other debugging tools, check [item 3](./LAYERS.md#3-apply-the-thresholds-where-needed).
+
+## 1. Generate Metrics by Layer
+
+The idea is to run, the prompts through the model with the pre- and post-hooks added, and then get the metrics for the outputs intercepted by each layer, as in this diagram. Then we can have a baseline with CPU/GPU for a failure threshold in AIU tests. Same idea as the [test_decoders.py](https://github.com/foundation-model-stack/aiu-fms-testing-utils/blob/main/tests/models/test_decoders.py), but for each layer. This way we can measure the discrepancies for the outputs and use the thresholds for detailed debugging problems in AIU.
+
+
+
+The script [generate_layers_metrics.py](../scripts/generate_layers_metrics.py) requires the following arguments to be run:
+
+```bash
+usage: generate_layers_metrics.py [-h] [--architecture ARCHITECTURE] [--variant VARIANT] [--model_path MODEL_PATH] --mode {generate,model-forward} --batch_sizes BATCH_SIZES --seq_lengths SEQ_LENGTHS --max_new_tokens MAX_NEW_TOKENS [--output_path OUTPUT_PATH] [--sharegpt_path SHAREGPT_PATH]
+
+Script to generate the model's metrics by layer
+
+options:
+ -h, --help show this help message and exit
+ --architecture ARCHITECTURE
+ The model architecture Eg.: hf_pretrained
+ --variant VARIANT The model variants (configuration) to benchmark. E.g. ibm-granite/granite-3.2-8b-instruct
+ --model_path MODEL_PATH
+ Paths to the directory containing model's weights (.pth files sharded by tensor parallel rank, not HF weights)
+ --mode {generate,model-forward}
+ Sets the output generation mode.
+ --batch_sizes BATCH_SIZES
+ Batch sizes separated by comma. Eg.: 1,2
+ --seq_lengths SEQ_LENGTHS
+ Sequence lengths separated by comma. Eg.: 64,2048
+ --max_new_tokens MAX_NEW_TOKENS
+ Max number of generated tokens separated by comma. Eg.: 64,128
+ --output_path OUTPUT_PATH
+ Path to save output files
+ --sharegpt_path SHAREGPT_PATH
+ Path to sharegpt data json
+```
+
+These variables support single and array values.
+
+The argument required for this script is the `--mode`, which is the generation mode desired for the output; The choices can be `generate` or `model-forward`.
+- `generate` uses FMS [generate](../scripts/generate_layers_metrics.py#L118); Itβs a high-level API that wraps many operations: forward pass, KV cache logic, sampling or greeting decoding, post-processing.
+```python
+result = generate(
+ model,
+ ids,
+ max_new_tokens=max_new_tokens,
+ use_cache=use_cache,
+ do_sample=do_sample,
+ max_seq_len=max_seq_len,
+ timing="e2e",
+ eos_token_id=None,
+ contiguous_cache=True,
+ extra_kwargs={},
+)
+```
+- `model-forward` will call [model.forward](../scripts/generate_layers_metrics.py#L135); Avoids introducing noise from sampling, past key caching, etc.
+```python
+result = model.forward(
+ ids,
+ use_cache=use_cache
+ )
+```
+
+### How to run
+
+Once all is set up, we can generate the CSV metrics:
+
+```bash
+cd aiu-fms-testing-utils/tests/resources
+
+mkdir /tmp/output
+
+python3 generate_layers_metrics.py --mode model-forward --variant ibm-granite/granite-3.2-8b-instruct --architecture hf_pretrained --batch_sizes 1 --seq_lengths 64 --max_new_tokens 128
+```
+The files should get created at `/tmp/output` dir:
+```bash
+ibm-granite--granite-3.2-8b-instruct_max-new-tokens-128_batch-size-1_seq-length-64_dtype-float16--model.base_model.layers7.ln.abs_diff.csv
+ibm-granite--granite-3.2-8b-instruct_max-new-tokens-128_batch-size-1_seq-length-64_dtype-float16--model.base_model.layers7.ln.cos_sim.csv
+ibm-granite--granite-3.2-8b-instruct_max-new-tokens-128_batch-size-1_seq-length-64_dtype-float16--model.base_model.layers8.attn.dense.abs_diff.csv
+ibm-granite--granite-3.2-8b-instruct_max-new-tokens-128_batch-size-1_seq-length-64_dtype-float16--model.base_model.layers8.attn.dense.cos_sim.csv
+```
+
+## 2. Get Thresholds
+
+To get the second step of the flow and get the thresholds by layer, run:
+```bash
+cd /aiu-fms-testing-utils/tests/resources
+
+python3 get_thresholds.py --models ibm-granite/granite-3.2-8b-instruct --metrics abs_diff cos_sim_avg cos_sim_men --file_base /tmp/output --layer_io
+```
+It should print the metric of each layer:
+```bash
+2025-07-09 19:02:40,657 found 484 layers metric files
+2025-07-09 19:02:40,674 Layer model.base_model.embedding abs_diff_linalg_norm = 1.7258892434335918e-07
+2025-07-09 19:02:40,690 Layer model.base_model.layers0.ln abs_diff_linalg_norm = 0.4083323414747196
+2025-07-09 19:02:40,707 Layer model.base_model.layers0.attn.in_proj.query abs_diff_linalg_norm = 0.7099368339133884
+2025-07-09 19:02:40,712 Layer model.base_model.layers0.attn.in_proj.key abs_diff_linalg_norm = 0.40915828503373886
+2025-07-09 19:02:40,716 Layer model.base_model.layers0.attn.in_proj.value abs_diff_linalg_norm = 0.12381335209555287
+2025-07-09 19:02:40,721 Layer model.base_model.layers0.attn.in_proj abs_diff_linalg_norm = 0.12381335209555287
+[...]
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.attn.in_proj.value cos_sim_avg = 0.9999685110524297
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.attn.in_proj cos_sim_avg = 0.9999685110524297
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.attn.dense cos_sim_avg = 0.9999954961240292
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.ff_ln cos_sim_avg = 1.0000354265794158
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.ff_sub_layer.wg cos_sim_avg = 1.0000474276021123
+2025-07-09 19:03:27,029 Layer model.base_model.layers39.ff_sub_layer.a cos_sim_avg = 1.0000188555568457
+[...]
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.attn.in_proj.query cos_sim_mean = 0.9999569654464722
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.attn.in_proj.key cos_sim_mean = 1.000030318275094
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.attn.in_proj.value cos_sim_mean = 0.9999886471778154
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.attn.in_proj cos_sim_mean = 0.9999886471778154
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.attn.dense cos_sim_mean = 1.0000049602240324
+2025-07-09 19:03:27,055 Layer model.base_model.layers0.ff_ln cos_sim_mean = 0.9999961135908961
+
+```
+Also, a JSON file is saved to the same output dir. A sample file can be found at: [sample_layer_th.json](https://github.com/flaviabeo/aiu-fms-testing-utils/blob/generate_metrics_layers/tests/resources/sample_layer_th.json)
+
+## 3. Apply the thresholds where needed
+
+In case of AIU debugging tools, the thresholds will be applied to compare AIU outputs with CPU, and then assert if the differences are within the thresholds generated. Below, is an architecture of the full integration:
+
+
+The box named `deepview layer debug` has the diagram of how the model layers outputs are generated to be compared against the CPU results. This is important so that the debug tools can catch operations and layers that have issues in their enablement for AIU hardware.
\ No newline at end of file
diff --git a/tests/MODEL.md b/tests/MODEL.md
new file mode 100644
index 00000000..0f7ba612
--- /dev/null
+++ b/tests/MODEL.md
@@ -0,0 +1,313 @@
+# Model Tests
+How to run the pytest test suites at [aiu-fms-testing-utils](https://github.com/aiu-fms-testing-utils/tree/main/tests/models).
+
+1. [Generate metrics in GPU](MODEL.md#1-run-first-on-gpu)
+2. [Get Thresholds](MODEL.md#2-get-thresholds)
+3. [Apply thresholds into test_decoders](MODEL.md#3-apply-thresholds-in-aiu-test_decoders)
+4. [Run test_model_expectations](MODEL.md#4-run-test_model_expectations)
+
+
+
+## The test scripts
+
+- **test_decoders** - this will test the decoder models (text-generation) with certain shapes. Most of this is configurable (model, batch_size, prompt_length, max_new_tokens, metrics_thresholds, failure_rate_thresholds, mini models, etc.)
+Example:
+```bash
+# Note: you might need an hf_token if the model requires it (this will download)
+export FMS_TEST_SHAPES_COMMON_BATCH_SIZES=1
+export FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS=128
+export FMS_TEST_SHAPES_COMMON_MODEL_PATHS=/local-path/granite-20b-code-cobol-v1/
+export FMS_TEST_SHAPES_USE_MICRO_MODELS=0
+pytest tests/models/test_decoders.py
+```
+The above will test shapes batch_size 1, with sequence length 128 of granite-20b-code-cobol-v1. We can set `FMS_TEST_SHAPES_USE_MICRO_MODELS=0` for not using micro models. Or set it to `FMS_TEST_SHAPES_USE_MICRO_MODELS=1` and add the micro models version to the model paths.
+
+- **test_model_expectations** - this test will capture a snapshot in time of what a randomly initialized model would produce on the AIU. To add a model to this, you simply add it to either the models list or tuple_output_models list which will generate 2 expectation tests. The first time you run this test, you run it with --capture_expectation which will create a resource file with the expected output. The next time you run it, you run without the --capture_expectation and all should pass.
+
+### Thresholds for the tests baselines for `test_decoders`
+
+The `test_decoders.py` file contains tests written for models that have **decoder** architecture. For each model to be tested, specific metrics baselines need to be created by following the next steps in this documentation. Four different metrics are generated with top k per token as base lines for these tests; Cross entropy loss per token, probability mean, probability standard deviation and absolute diff mean.
+
+- **cross_entropy**: Cross entropy is a measure from information theory that quantifies the difference between two probability distributions. Cross entropy serves as a measure of the differences when comparing expected generated tokens and the actual output of the model. Quantifying the distance between the ground-truth distribution and the predicted distribution.
+A lower cross entropy indicates a closer match in expected versus generated.
+- **prob_mean**: Probability Mean typically refers to the average probability assigned by the model to a sequence of words or tokens. It's a measure of how well the model understands and predicts language, with lower mean probabilities often indicating a poorer model that struggles to generate coherent or plausible text.
+- **prob_std**: Probability standard deviation assesses how spread out or consistent the model's predictions are when it assigns probabilities to different possible outcomes. A high standard deviation indicates wide variation in the model's certainty, while a low standard deviation suggests more consistent and confident prediction
+- **diff_mean**: The difference of the average or central tendency of a set of data points, often used to measure the model's performance. It can also refer to the intended purpose or interpretation of a text or sentence produced by the model.
+
+They are calculated in lines [228 - 231 at generate_metrics.py](../scripts/generate_metrics.py#L253) script.
+```python
+cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
+prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
+prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
+diff_mean = lambda r, t: torch.mean(torch.abs(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32)))
+```
+More at [pytorch.org](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html), [Yiren,Wang](https://courses.grainger.illinois.edu/ece598pv/fa2017/Lecture13_LM_YirenWang.pdf), [Li, Wang, Shang Et al.](https://arxiv.org/abs/2412.12177#:~:text=%5B2412.12177%5D%20Model%2Ddiff:,%3E%20cs%20%3E%20arXiv:2412.12177) and [Wu,Hilton](https://arxiv.org/html/2410.13211v1).
+
+
+This metrics will be set at the [fail thresholds](./models/test_decoders.py#L146), so **cross_entropy** and **diff_mean** can be used to compare between the GPU generated text output by the same model in AIU.
+
+## 1. Run first on GPU
+
+Set shapes:
+```bash
+export MODEL_PATH=/model-path/
+export MAX_NEW_TOKENS=128
+export BATCH_SIZES=1
+export SEQ_LENS=64
+export DEFAULT_TYPES="fp16"
+export DS_PATH=/resources/sharegpt/share_gpt.json
+```
+
+Then run the command for the metrics script:
+```bash
+python generate_metrics.py --architecture=hf_pretrained --model_path=$MODEL_PATH --tokenizer=$MODEL_PATH --unfuse_weights --output_dir=/tmp/aiu-fms-testing-utils/output/ --compile_dynamic --max_new_tokens=$MAX_NEW_TOKENS --min_pad_length=$SEQ_LENS --batch_size=$BATCH_SIZES --default_dtype=$DEFAULT_TYPES --sharegpt_path=$DS_PATH --num_test_tokens_per_sequence=1024
+```
+
+This will generate csv files with the results of the metrics calculation. Typically, this is run with batch size 1, 8 and sequency length 64, 2048 (4 runs in total). Then, we can run [get_thresholds.py](./resources/get_thresholds.py) to summarize the results and get the single values for each metric as the following.
+
+At the output path, you will see the out and csv files generated as the sample in the following lines:
+```bash
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.ce.csv
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.0.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.1.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.2.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.3.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.4.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.5.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.6.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.7.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.0.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.1.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.2.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.3.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.4.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.5.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.6.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.7.out
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.diff_mean.csv
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.prob_mean.csv
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.prob_std.csv
+```
+## 2. Get Thresholds
+Get the thresholds by running the [get_thresholds.py](./resources/get_thresholds.py):
+```bash
+python3 get_thresholds.py --models /tmp/aiu-fms-testing-utils/models/model-name-version-v1 --metrics diff_mean ce --file_base /tmp/aiu-fms-testing-utils/output
+```
+After running these scripts in namespace with 1 GPU, these were the thresholds generated:
+
+```bash
+python3 get_thresholds.py --models /tmp/aiu-fms-testing-utils/models/Mistral-7B-Instruct-v0.3 --metrics diff_mean ce --file_base /tmp/aiu-fms-testing-utils/output
+found 7 metric files
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3 diff_mean 0.0007839603102183846
+found 7 metric files
+--tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3 ce 2.8364005851745624
+```
+
+These can now be used for the model testing scripts at AIU.
+
+## 3. Apply thresholds in AIU `test_decoders`
+
+These are the variables set at the deployment:
+
+| Name | Value
+| ------------- | ----------------
+| FMS_TEST_SHAPES_COMMON_MODEL_PATHS | mistralai/Mistral-7B-Instruct-v0.3
+| FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1 | 1
+| FMS_TEST_SHAPES_COMMON_BATCH_SIZES | 1
+| FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS | 64
+| FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS | 16
+| FMS_TEST_SHAPES_USE_MICRO_MODELS | 0
+| FMS_TEST_SHAPES_METRICS_THRESHOLD | 2.8364005851745624,0.0007839603102183846
+
+
+> Set `FMS_TEST_SHAPES_METRICS_THRESHOLD` in case there is no need to add the model to the default ones. No code changes needed, just this environment variable set with the metrics values. Set `FMS_TEST_SHAPES_VALIDATION_INFO_DIR` to speed up the tests considerably when testing larger models by using the output logits saved from generating the metrics. Set `FMS_TEST_SHAPES_FAILURE_THRESHOLD` if you would like to relax the threshold - default is `0.01`.
+
+Add the new numbers at the end of the [dictionary](./models/test_decoders.py#L116):
+```python
+# thresholds are chosen based on 1024 tokens per sequence
+# 1% error threshold rate between cpu fp32 and cuda fp16
+# if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
+# threshold key is (model_id, is_tiny_model)
+fail_thresholds = {
+ (LLAMA_3p1_8B_INSTRUCT, True): (
+ 3.7392955756187423,
+ .001, # FIXME: compute
+ ),
+ (GRANITE_3p2_8B_INSTRUCT, True): (
+ 2.996668996810913,
+ .001, # FIXME: compute
+ ),
+ (GRANITE_20B_CODE_INSTRUCT_8K, True): (
+ 3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct
+ .001, # FIXME: compute
+ ),
+ (LLAMA_3p1_70B_INSTRUCT, True): (
+ 3.8235735702514626,
+ .001, # FIXME: compute
+ ),
+ (LLAMA_3p1_8B_INSTRUCT, False): (
+ 2.6994638133048965,
+ 0.00047589250549208347,
+ ),
+ (GRANITE_3p2_8B_INSTRUCT, False): (
+ 2.3919514417648315,
+ 0.0005767398688476533,
+ ),
+ (GRANITE_20B_CODE_INSTRUCT_8K, False): (
+ 2.640706129074097,
+ 0.00034344267623964697,
+ ),
+ (LLAMA_3p1_70B_INSTRUCT, False): (
+ 2.841279556751251,
+ 0.0044301633024588115,
+ ),
+}
+```
+
+The command to run is:
+```bash
+pytest tests/models/test_decoders.py -vv
+```
+Add the `-vv` for verbose output.
+
+### Test Results Samples
+
+Here is a result sample of the test outputs:
+
+```bash
+Starting to run pytest tests/models/test_decoders.py
+[ 0/ 1]: Sentient AIU: Enabled
+============================= test session starts ==============================
+platform linux -- Python 3.11.9, pytest-8.3.5, pluggy-1.5.0
+rootdir: /tmp/aiu-fms-testing-utils
+plugins: durations-1.4.0, env-1.1.5
+collected 1 item
+
+tests/models/test_decoders.py . [100%]
+
+=============================== warnings summary ===============================
+../foundation-model-stack/fms/triton/pytorch_ops.py:103
+ /tmp/foundation-model-stack/fms/triton/pytorch_ops.py:103: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
+ @torch.library.impl_abstract("moe::moe_mm")
+
+-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
+============================= fixture duration top =============================
+total name num avg min
+0:00:00.000140 grand total 5 0:00:00.000014 0:00:00.000012
+============================ test call duration top ============================
+total name num avg min
+0:02:18.965102 test_common_shapes 1 0:02:18.965102 0:02:18.965102
+0:02:18.965102 grand total 1 0:02:18.965102 0:02:18.965102
+=========================== test setup duration top ============================
+total name num avg min
+0:00:00.000553 grand total 1 0:00:00.000553 0:00:00.000553
+========================== test teardown duration top ==========================
+total name num avg min
+0:00:00.000969 grand total 1 0:00:00.000969 0:00:00.000969
+=================== 1 passed, 1 warning in 140.35s (0:02:20) ===================
+Finished running pytests
+```
+In case the thresholds fails:
+```bash
+[ 0/ 1]: testing model=/mnt/aiu-models-en-shared/models/hf/Mistral-7B-Instruct-v0.3, batch_size=1, seq_length=64, max_new_tokens=16, micro_model=False
+[ 0/ 1]: AIU warmup
+Using AIU_TOPO_FILE=/etc/aiu/topo.json
+[ 0/ 1]: PT compile complete, took 211.912s
+[ 0/ 1]: cpu validation info extracted for validation level 0 and validation level 1 (iter=0)
+[ 0/ 1]: aiu validation info extracted for validation level 0
+[ 0/ 1]: failed validation level 0, testing validation level 1
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=0
+[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=1
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=1
+[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=2
+[...] (iteractions removed for better readability)
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=60
+[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=61
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=61
+[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=62
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=62
+[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=63
+[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=63
+[ 0/ 1]: mean diff failure rate: 0.7638888888888888
+[ 0/ 1]: cross entropy loss failure rate: 0.000992063492063492
+===================================================================================== fixture duration top =====================================================================================
+total name num avg min
+0:00:00.000130 grand total 5 0:00:00.000012 0:00:00.000009
+==================================================================================== test call duration top ====================================================================================
+total name num avg min
+0:16:31.480337 test_common_shapes 1 0:16:31.480337 0:16:31.480337
+0:16:31.480337 grand total 1 0:16:31.480337 0:16:31.480337
+=================================================================================== test setup duration top ====================================================================================
+total name num avg min
+0:00:00.000555 grand total 1 0:00:00.000555 0:00:00.000555
+================================================================================== test teardown duration top ==================================================================================
+total name num avg min
+0:00:00.001416 grand total 1 0:00:00.001416 0:00:00.001416
+=================================================================================== short test summary info ====================================================================================
+FAILED tests/models/test_decoders.py::test_common_shapes[/mnt/aiu-models-en-shared/models/hf/Mistral-7B-Instruct-v0.3-1-64-16] - AssertionError: failure rate for mean diff was too high: 0.7638888888888888
+assert 0.7638888888888888 < 0.01
+```
+## 4. Run `test_model_expectations`
+
+- First add the desired model to the [decoder_models](./models/test_model_expectations.py#L55) variable.
+- If the models tested are too big, it is a valid option to add the micro model version for this specific test.
+- 4.1 Run `pytest tests/models/test_model_expectations.py::TestAIUDecoderModels --capture_expectation` to save the model weights.
+After that you will get an output like this:
+```bash
+FAILED tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_output[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-True] - Failed: Signature file has been saved, please re-run the tests without --capture_expectation
+FAILED tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_weight_keys[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-True] - Failed: Weights Key file has been saved, please re-run the tests without --capture_expectation
+```
+This will tell that the weights and signature have been saved, so you can run the complete suite again to get the tests results.
+- 4.2 Then running the complete suite:
+
+```bash
+[1000780000@e2e-vllm-dt2-646f66647b-68dh6 aiu-fms-testing-utils]$ pytest tests/models/test_model_expectations.py::TestAIUDecoderModels -vv
+[ 0/ 1]: Sentient AIU: Enabled
+===================================================================================== test session starts ======================================================================================
+platform linux -- Python 3.12.5, pytest-8.3.5, pluggy-1.5.0 -- /usr/bin/python3.12
+cachedir: .pytest_cache
+rootdir: /tmp/aiu-fms-testing-utils
+plugins: durations-1.5.2, env-1.1.5
+collected 3 items
+
+tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_output[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 33%]
+tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_weight_keys[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 66%]
+tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_unfused[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3] SKIPPED (All AIU models are already unfused) [100%]
+
+===================================================================================== fixture duration top =====================================================================================
+total name num avg min
+0:00:02.201162 uninitialized_model 1 0:00:02.201162 0:00:02.201162
+0:00:00.051478 model 1 0:00:00.051478 0:00:00.051478
+0:00:02.252951 grand total 6 0:00:00.000135 0:00:00.000046
+==================================================================================== test call duration top ====================================================================================
+total name num avg min
+0:03:05.951278 TestAIUDecoderModels::test_model_output 1 0:03:05.951278 0:03:05.951278
+0:03:05.954470 grand total 3 0:00:00.003095 0:00:00.000097
+=================================================================================== test setup duration top ====================================================================================
+total name num avg min
+0:00:00.002004 grand total 3 0:00:00.000289 0:00:00.000102
+================================================================================== test teardown duration top ==================================================================================
+total name num avg min
+0:00:00.000363 grand total 3 0:00:00.000090 0:00:00.000077
+=========================================================================== 2 passed, 1 skipped in 189.01s (0:03:09) ===========================================================================
+
+```
+
+In this case, the model tested was a decoder model with a single output, the `TestAIUDecoderModels` is the most important case. In the next section, check the applicability for the [TestAIUModelsTupleOutput](./MODEL.md#case-of-multiple-output---testaiumodelstupleoutput) cases.
+
+#### Case of multiple output - TestAIUModelsTupleOutput
+
+The case **TestAIUModelsTupleOutput** is applicable if the model being tested has output of more than one tensor. Like the model in the example default [tuple_output_models](./models/test_model_expectations.py#L76), is a RoBERTa model that can output in this different format.
+
+- Add the model also to [tuple_output_models](./models/test_model_expectations.py#L76).
+- 4.1 Run `pytest tests/models/test_model_expectations.py::TestAIUModelsTupleOutput --capture_expectation` to save the model weights;
+
+```bash
+tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_output[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 66%]
+tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_weight_keys[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 83%]
+tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_unfused[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603] SKIPPED [100%]
+```
+
+> When adding new models expectations, please include in the PR with capture expectation tests added, the date of the image used to generate the file.
+
+Check this example of a PR for adding a new model expectations' files and results [here](https://github.com/foundation-model-stack/aiu-fms-testing-utils/pull/48).
diff --git a/tests/README.md b/tests/README.md
index 836181c3..07a9e533 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -1,313 +1,6 @@
-# Model Tests
-How to run the pytest test suites at [aiu-fms-testing-utils](https://github.com/aiu-fms-testing-utils/tree/main/tests/models).
+# Tests and Metrics generation
-1. [Generate metrics in GPU](README.md#1-run-first-on-gpu)
-2. [Get Thresholds](README.md#2-get-thresholds)
-3. [Apply thresholds into test_decoders](README.md#3-apply-thresholds-in-aiu-test_decoders)
-4. [Run test_model_expectations](README.md#4-run-test_model_expectations)
+In these docs you will find:
-
-
-## The test scripts
-
-- **test_decoders** - this will test the decoder models (text-generation) with certain shapes. Most of this is configurable (model, batch_size, prompt_length, max_new_tokens, metrics_thresholds, failure_rate_thresholds, mini models, etc.)
-Example:
-```bash
-# Note: you might need an hf_token if the model requires it (this will download)
-export FMS_TEST_SHAPES_COMMON_BATCH_SIZES=1
-export FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS=128
-export FMS_TEST_SHAPES_COMMON_MODEL_PATHS=/local-path/granite-20b-code-cobol-v1/
-export FMS_TEST_SHAPES_USE_MICRO_MODELS=0
-pytest tests/models/test_decoders.py
-```
-The above will test shapes batch_size 1, with sequence length 128 of granite-20b-code-cobol-v1. We can set `FMS_TEST_SHAPES_USE_MICRO_MODELS=0` for not using micro models. Or set it to `FMS_TEST_SHAPES_USE_MICRO_MODELS=1` and add the micro models version to the model paths.
-
-- **test_model_expectations** - this test will capture a snapshot in time of what a randomly initialized model would produce on the AIU. To add a model to this, you simply add it to either the models list or tuple_output_models list which will generate 2 expectation tests. The first time you run this test, you run it with --capture_expectation which will create a resource file with the expected output. The next time you run it, you run without the --capture_expectation and all should pass.
-
-### Thresholds for the tests baselines for `test_decoders`
-
-The `test_decoders.py` file contains tests written for models that have **decoder** architecture. For each model to be tested, specific metrics baselines need to be created by following the next steps in this documentation. Four different metrics are generated with top k per token as base lines for these tests; Cross entropy loss per token, probability mean, probability standard deviation and absolute diff mean.
-
-- **cross_entropy**: Cross entropy is a measure from information theory that quantifies the difference between two probability distributions. Cross entropy serves as a measure of the differences when comparing expected generated tokens and the actual output of the model. Quantifying the distance between the ground-truth distribution and the predicted distribution.
-A lower cross entropy indicates a closer match in expected versus generated.
-- **prob_mean**: Probability Mean typically refers to the average probability assigned by the model to a sequence of words or tokens. It's a measure of how well the model understands and predicts language, with lower mean probabilities often indicating a poorer model that struggles to generate coherent or plausible text.
-- **prob_std**: Probability standard deviation assesses how spread out or consistent the model's predictions are when it assigns probabilities to different possible outcomes. A high standard deviation indicates wide variation in the model's certainty, while a low standard deviation suggests more consistent and confident prediction
-- **diff_mean**: The difference of the average or central tendency of a set of data points, often used to measure the model's performance. It can also refer to the intended purpose or interpretation of a text or sentence produced by the model.
-
-They are calculated in lines [228 - 231 at generate_metrics.py](../scripts/generate_metrics.py#L253) script.
-```python
-cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
-prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
-prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
-diff_mean = lambda r, t: torch.mean(torch.abs(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32)))
-```
-More at [pytorch.org](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html), [Yiren,Wang](https://courses.grainger.illinois.edu/ece598pv/fa2017/Lecture13_LM_YirenWang.pdf), [Li, Wang, Shang Et al.](https://arxiv.org/abs/2412.12177#:~:text=%5B2412.12177%5D%20Model%2Ddiff:,%3E%20cs%20%3E%20arXiv:2412.12177) and [Wu,Hilton](https://arxiv.org/html/2410.13211v1).
-
-
-This metrics will be set at the [fail thresholds](./models/test_decoders.py#L146), so **cross_entropy** and **diff_mean** can be used to compare between the GPU generated text output by the same model in AIU.
-
-## 1. Run first on GPU
-
-Set shapes:
-```bash
-export MODEL_PATH=/model-path/
-export MAX_NEW_TOKENS=128
-export BATCH_SIZES=1
-export SEQ_LENS=64
-export DEFAULT_TYPES="fp16"
-export DS_PATH=/resources/sharegpt/share_gpt.json
-```
-
-Then run the command for the metrics script:
-```bash
-python generate_metrics.py --architecture=hf_pretrained --model_path=$MODEL_PATH --tokenizer=$MODEL_PATH --unfuse_weights --output_dir=/tmp/aiu-fms-testing-utils/output/ --compile_dynamic --max_new_tokens=$MAX_NEW_TOKENS --min_pad_length=$SEQ_LENS --batch_size=$BATCH_SIZES --default_dtype=$DEFAULT_TYPES --sharegpt_path=$DS_PATH --num_test_tokens_per_sequence=1024
-```
-
-This will generate csv files with the results of the metrics calculation. Typically, this is run with batch size 1, 8 and sequency length 64, 2048 (4 runs in total). Then, we can run [get_thresholds.py](./resources/get_thresholds.py) to summarize the results and get the single values for each metric as the following.
-
-At the output path, you will see the out and csv files generated as the sample in the following lines:
-```bash
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.ce.csv
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.0.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.1.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.2.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.3.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.4.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.5.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.6.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cpu_validation_info.7.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.0.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.1.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.2.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.3.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.4.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.5.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.6.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.cuda_validation_info.7.out
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.diff_mean.csv
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.prob_mean.csv
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3_max-new-tokens-128_batch-size-8_seq-length64_dtype-fp16.prob_std.csv
-```
-## 2. Get Thresholds
-Get the thresholds by running the [get_thresholds.py](./resources/get_thresholds.py):
-```bash
-python3 get_thresholds.py --models /tmp/aiu-fms-testing-utils/models/model-name-version-v1 --metrics diff_mean ce --file_base /tmp/aiu-fms-testing-utils/output
-```
-After running these scripts in namespace with 1 GPU, these were the thresholds generated:
-
-```bash
-python3 get_thresholds.py --models /tmp/aiu-fms-testing-utils/models/Mistral-7B-Instruct-v0.3 --metrics diff_mean ce --file_base /tmp/aiu-fms-testing-utils/output
-found 7 metric files
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3 diff_mean 0.0007839603102183846
-found 7 metric files
---tmp--aiu-fms-testing-utils--models--Mistral-7B-Instruct-v0.3 ce 2.8364005851745624
-```
-
-These can now be used for the model testing scripts at AIU.
-
-## 3. Apply thresholds in AIU `test_decoders`
-
-These are the variables set at the deployment:
-
-| Name | Value
-| ------------- | ----------------
-| FMS_TEST_SHAPES_COMMON_MODEL_PATHS | mistralai/Mistral-7B-Instruct-v0.3
-| FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1 | 1
-| FMS_TEST_SHAPES_COMMON_BATCH_SIZES | 1
-| FMS_TEST_SHAPES_COMMON_SEQ_LENGTHS | 64
-| FMS_TEST_SHAPES_COMMON_MAX_NEW_TOKENS | 16
-| FMS_TEST_SHAPES_USE_MICRO_MODELS | 0
-| FMS_TEST_SHAPES_METRICS_THRESHOLD | 2.8364005851745624,0.0007839603102183846
-
-
-> Set `FMS_TEST_SHAPES_METRICS_THRESHOLD` in case there is no need to add the model to the default ones. No code changes needed, just this environment variable set with the metrics values. Set `FMS_TEST_SHAPES_VALIDATION_INFO_DIR` to speed up the tests considerably when testing larger models by using the output logits saved from generating the metrics. Set `FMS_TEST_SHAPES_FAILURE_THRESHOLD` if you would like to relax the threshold - default is `0.01`.
-
-Add the new numbers at the end of the [dictionary](./models/test_decoders.py#L116):
-```python
-# thresholds are chosen based on 1024 tokens per sequence
-# 1% error threshold rate between cpu fp32 and cuda fp16
-# if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
-# threshold key is (model_id, is_tiny_model)
-fail_thresholds = {
- (LLAMA_3p1_8B_INSTRUCT, True): (
- 3.7392955756187423,
- .001, # FIXME: compute
- ),
- (GRANITE_3p2_8B_INSTRUCT, True): (
- 2.996668996810913,
- .001, # FIXME: compute
- ),
- (GRANITE_20B_CODE_INSTRUCT_8K, True): (
- 3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct
- .001, # FIXME: compute
- ),
- (LLAMA_3p1_70B_INSTRUCT, True): (
- 3.8235735702514626,
- .001, # FIXME: compute
- ),
- (LLAMA_3p1_8B_INSTRUCT, False): (
- 2.6994638133048965,
- 0.00047589250549208347,
- ),
- (GRANITE_3p2_8B_INSTRUCT, False): (
- 2.3919514417648315,
- 0.0005767398688476533,
- ),
- (GRANITE_20B_CODE_INSTRUCT_8K, False): (
- 2.640706129074097,
- 0.00034344267623964697,
- ),
- (LLAMA_3p1_70B_INSTRUCT, False): (
- 2.841279556751251,
- 0.0044301633024588115,
- ),
-}
-```
-
-The command to run is:
-```bash
-pytest tests/models/test_decoders.py -vv
-```
-Add the `-vv` for verbose output.
-
-### Test Results Samples
-
-Here is a result sample of the test outputs:
-
-```bash
-Starting to run pytest tests/models/test_decoders.py
-[ 0/ 1]: Sentient AIU: Enabled
-============================= test session starts ==============================
-platform linux -- Python 3.11.9, pytest-8.3.5, pluggy-1.5.0
-rootdir: /tmp/aiu-fms-testing-utils
-plugins: durations-1.4.0, env-1.1.5
-collected 1 item
-
-tests/models/test_decoders.py . [100%]
-
-=============================== warnings summary ===============================
-../foundation-model-stack/fms/triton/pytorch_ops.py:103
- /tmp/foundation-model-stack/fms/triton/pytorch_ops.py:103: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
- @torch.library.impl_abstract("moe::moe_mm")
-
--- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
-============================= fixture duration top =============================
-total name num avg min
-0:00:00.000140 grand total 5 0:00:00.000014 0:00:00.000012
-============================ test call duration top ============================
-total name num avg min
-0:02:18.965102 test_common_shapes 1 0:02:18.965102 0:02:18.965102
-0:02:18.965102 grand total 1 0:02:18.965102 0:02:18.965102
-=========================== test setup duration top ============================
-total name num avg min
-0:00:00.000553 grand total 1 0:00:00.000553 0:00:00.000553
-========================== test teardown duration top ==========================
-total name num avg min
-0:00:00.000969 grand total 1 0:00:00.000969 0:00:00.000969
-=================== 1 passed, 1 warning in 140.35s (0:02:20) ===================
-Finished running pytests
-```
-In case the thresholds fails:
-```bash
-[ 0/ 1]: testing model=/mnt/aiu-models-en-shared/models/hf/Mistral-7B-Instruct-v0.3, batch_size=1, seq_length=64, max_new_tokens=16, micro_model=False
-[ 0/ 1]: AIU warmup
-Using AIU_TOPO_FILE=/etc/aiu/topo.json
-[ 0/ 1]: PT compile complete, took 211.912s
-[ 0/ 1]: cpu validation info extracted for validation level 0 and validation level 1 (iter=0)
-[ 0/ 1]: aiu validation info extracted for validation level 0
-[ 0/ 1]: failed validation level 0, testing validation level 1
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=0
-[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=1
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=1
-[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=2
-[...] (iteractions removed for better readability)
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=60
-[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=61
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=61
-[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=62
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=62
-[ 0/ 1]: cpu validation info extracted for validation level 1 - iter=63
-[ 0/ 1]: aiu validation info extracted for validation level 1 - iter=63
-[ 0/ 1]: mean diff failure rate: 0.7638888888888888
-[ 0/ 1]: cross entropy loss failure rate: 0.000992063492063492
-===================================================================================== fixture duration top =====================================================================================
-total name num avg min
-0:00:00.000130 grand total 5 0:00:00.000012 0:00:00.000009
-==================================================================================== test call duration top ====================================================================================
-total name num avg min
-0:16:31.480337 test_common_shapes 1 0:16:31.480337 0:16:31.480337
-0:16:31.480337 grand total 1 0:16:31.480337 0:16:31.480337
-=================================================================================== test setup duration top ====================================================================================
-total name num avg min
-0:00:00.000555 grand total 1 0:00:00.000555 0:00:00.000555
-================================================================================== test teardown duration top ==================================================================================
-total name num avg min
-0:00:00.001416 grand total 1 0:00:00.001416 0:00:00.001416
-=================================================================================== short test summary info ====================================================================================
-FAILED tests/models/test_decoders.py::test_common_shapes[/mnt/aiu-models-en-shared/models/hf/Mistral-7B-Instruct-v0.3-1-64-16] - AssertionError: failure rate for mean diff was too high: 0.7638888888888888
-assert 0.7638888888888888 < 0.01
-```
-## 4. Run `test_model_expectations`
-
-- First add the desired model to the [decoder_models](./models/test_model_expectations.py#L55) variable.
-- If the models tested are too big, it is a valid option to add the micro model version for this specific test.
-- 4.1 Run `pytest tests/models/test_model_expectations.py::TestAIUDecoderModels --capture_expectation` to save the model weights.
-After that you will get an output like this:
-```bash
-FAILED tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_output[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-True] - Failed: Signature file has been saved, please re-run the tests without --capture_expectation
-FAILED tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_weight_keys[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-True] - Failed: Weights Key file has been saved, please re-run the tests without --capture_expectation
-```
-This will tell that the weights and signature have been saved, so you can run the complete suite again to get the tests results.
-- 4.2 Then running the complete suite:
-
-```bash
-[1000780000@e2e-vllm-dt2-646f66647b-68dh6 aiu-fms-testing-utils]$ pytest tests/models/test_model_expectations.py::TestAIUDecoderModels -vv
-[ 0/ 1]: Sentient AIU: Enabled
-===================================================================================== test session starts ======================================================================================
-platform linux -- Python 3.12.5, pytest-8.3.5, pluggy-1.5.0 -- /usr/bin/python3.12
-cachedir: .pytest_cache
-rootdir: /tmp/aiu-fms-testing-utils
-plugins: durations-1.5.2, env-1.1.5
-collected 3 items
-
-tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_output[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 33%]
-tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_weight_keys[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 66%]
-tests/models/test_model_expectations.py::TestAIUDecoderModels::test_model_unfused[/tmp/models/mistralai/Mistral-7B-Instruct-v0.3] SKIPPED (All AIU models are already unfused) [100%]
-
-===================================================================================== fixture duration top =====================================================================================
-total name num avg min
-0:00:02.201162 uninitialized_model 1 0:00:02.201162 0:00:02.201162
-0:00:00.051478 model 1 0:00:00.051478 0:00:00.051478
-0:00:02.252951 grand total 6 0:00:00.000135 0:00:00.000046
-==================================================================================== test call duration top ====================================================================================
-total name num avg min
-0:03:05.951278 TestAIUDecoderModels::test_model_output 1 0:03:05.951278 0:03:05.951278
-0:03:05.954470 grand total 3 0:00:00.003095 0:00:00.000097
-=================================================================================== test setup duration top ====================================================================================
-total name num avg min
-0:00:00.002004 grand total 3 0:00:00.000289 0:00:00.000102
-================================================================================== test teardown duration top ==================================================================================
-total name num avg min
-0:00:00.000363 grand total 3 0:00:00.000090 0:00:00.000077
-=========================================================================== 2 passed, 1 skipped in 189.01s (0:03:09) ===========================================================================
-
-```
-
-In this case, the model tested was a decoder model with a single output, the `TestAIUDecoderModels` is the most important case. In the next section, check the applicability for the [TestAIUModelsTupleOutput](./README.md#case-of-multiple-output---testaiumodelstupleoutput) cases.
-
-#### Case of multiple output - TestAIUModelsTupleOutput
-
-The case **TestAIUModelsTupleOutput** is applicable if the model being tested has output of more than one tensor. Like the model in the example default [tuple_output_models](./models/test_model_expectations.py#L76), is a RoBERTa model that can output in this different format.
-
-- Add the model also to [tuple_output_models](./models/test_model_expectations.py#L76).
-- 4.1 Run `pytest tests/models/test_model_expectations.py::TestAIUModelsTupleOutput --capture_expectation` to save the model weights;
-
-```bash
-tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_output[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 66%]
-tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_weight_keys[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603-False] <- ../foundation-model-stack/fms/testing/_internal/model_test_suite.py PASSED [ 83%]
-tests/models/test_model_expectations.py::TestAIUModelsTupleOutput::test_model_unfused[/ibm-dmf/models/watsonx/shared/granite-20b-code-cobol-v1/20240603] SKIPPED [100%]
-```
-
-> When adding new models expectations, please include in the PR with capture expectation tests added, the date of the image used to generate the file.
-
-Check this example of a PR for adding a new model expectations' files and results [here](https://github.com/foundation-model-stack/aiu-fms-testing-utils/pull/48).
\ No newline at end of file
+1. [Generate and run the model tests in this repo](./MODEL.md#model-tests)
+2. [Generate metrics by layer](./LAYERS.md#layer-metrics-generation)
diff --git a/tests/resources/assets/metrics_fms_deepview_integration.full.png b/tests/resources/assets/metrics_fms_deepview_integration.full.png
new file mode 100644
index 00000000..f8822dff
Binary files /dev/null and b/tests/resources/assets/metrics_fms_deepview_integration.full.png differ
diff --git a/tests/resources/assets/metrics_fms_deepview_integration.zoom.png b/tests/resources/assets/metrics_fms_deepview_integration.zoom.png
new file mode 100644
index 00000000..4953ece7
Binary files /dev/null and b/tests/resources/assets/metrics_fms_deepview_integration.zoom.png differ
diff --git a/tests/resources/assets/metrics_generation_layers.png b/tests/resources/assets/metrics_generation_layers.png
new file mode 100644
index 00000000..ca783141
Binary files /dev/null and b/tests/resources/assets/metrics_generation_layers.png differ
diff --git a/tests/resources/get_thresholds.py b/tests/resources/get_thresholds.py
index 7dedb70c..3e820246 100644
--- a/tests/resources/get_thresholds.py
+++ b/tests/resources/get_thresholds.py
@@ -1,8 +1,18 @@
import glob
-import os
import numpy as np
import argparse
import os
+import re
+
+import logging
+
+import json
+
+from aiu_fms_testing_utils.utils.metrics_utils import abs_diff_linalg_norm, list_mean
+
+logger = logging.getLogger(__name__)
+LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(message)s")
parser = argparse.ArgumentParser(
description="Script to get thresholds metrics"
@@ -22,7 +32,7 @@
default=[],
nargs='+',
required=True,
- help="List of metrics separated by space. Eg.: diff_mean ce",
+ help="List of metrics separated by space. Eg. for full model mode: diff_mean ce | Eg. for layers mode: abs_diff cos_sim_avg cos_sim_mean",
)
parser.add_argument(
"--file_base",
@@ -31,22 +41,102 @@
required=True,
help="Path where the thresholds output from the generate_metrics.py script were stored.",
)
+parser.add_argument(
+ "--layer_io",
+ action="store_true",
+ help="Sets the metric generation mode to layers IO"
+)
+parser.add_argument(
+ "--output_path",
+ type=str,
+ default="/tmp/aiu-fms-testing-utils/output",
+ required=False,
+ help="Path where the json thresholds output for the layers will be saved.",
+)
args = parser.parse_args()
models = [model.replace("/", "--") for model in args.models]
metrics = [metric for metric in args.metrics]
file_base = args.file_base
+layer_mode = args.file_base if args.file_base else False
+generate_mode_pattern = r"\.(iter-)([0-9]+)"
+
+
+def load_metric_file(file_path, layer_header, values):
+ """
+ Loads a metric file and returns its values as a list of floats.
+
+ Args:
+ file_path (str): The path to the metric file.
+ layer_header (bool): Whether to skip the first three lines of the file. Default is False.
+ values (list): Metrics values list
+
+ Returns:
+ list[float]: A list of metric values read from the file.
+ """
+ try:
+ with open(file_path, "r") as file:
+ if layer_header:
+ for _ in range(3):
+ next(file)
+ else:
+ next(file) # skip single header
+ for line in file:
+ values.append(float(line))
+ except StopIteration:
+ logger.info("Path empty or no more metric files found.")
+ pass
+ return values
for model in models:
+ result_dict = {"model_id": model}
for metric in metrics:
- path = os.path.join(file_base, f"{model}*{metric}*.csv")
+ metric_name = "_".join(metric.split("_")[:2]) if layer_mode else metric
+ path = os.path.join(file_base, f"{model}*{metric_name}*.csv")
metric_files = glob.glob(path)
+ result_dict[metric] = {}
+ if not layer_mode:
+ metric_list = []
+ for metric_file in metric_files:
+ metric_list = load_metric_file(metric_file, layer_mode, metric_list)
+ logger.info(f"found {len(metric_files)} metric files")
+ logger.info(model, metric, np.percentile(metric_list, 99.0))
+ else:
+ layers = {}
+ for metric_file in metric_files:
+ layer_dict = {}
+ metric_layer_list = []
+ layer_name = metric_file.split("--")[-1].replace(".{}".format(metric_name), "")
+ layer_name = layer_name.replace(".csv","")
+ metric_layer_list = load_metric_file(metric_file, layer_mode, metric_layer_list)
+ if re.search(generate_mode_pattern, layer_name):
+ layer_name = re.sub(generate_mode_pattern, "", layer_name)
+ if not layer_name in layers.keys():
+ layers[layer_name] = metric_layer_list
+ else:
+ layers[layer_name].extend(metric_layer_list)
+ logger.debug(f"Output layer with generate mode {layer_name}")
+ else:
+ layer_dict[layer_name] = metric_layer_list
+ logger.debug(f"Output layer {layer_name}")
+ layers.update(layer_dict)
+ logger.info(f"found {len(metric_files)} layers metric files")
- metric_list = []
- for metric_file in metric_files:
+ for key, l in layers.items():
+ l = np.nan_to_num(l, nan=0.0)
+ if "abs_diff" in metric:
+ metric_val = abs_diff_linalg_norm(l)
+ logger.info(f"Layer {key} abs_diff_linalg_norm = {metric_val}")
+ result_dict[metric][key] = metric_val
+ elif "avg" in metric:
+ metric_avg = np.average(l)
+ logger.info(f"Layer {key} {metric} = {metric_avg}")
+ result_dict[metric][key] = metric_avg
+ elif "mean" in metric:
+ metric_mean = list_mean(l)
+ logger.info(f"Layer {key} {metric} = {metric_mean}")
+ result_dict[metric][key] = metric_mean
- with open(metric_file, "r") as file:
- next(file)
- for line in file:
- metric_list.append(float(line))
- print(f"found {len(metric_files)} metric files")
- print(model, metric, np.percentile(metric_list, 99.0))
+ json_output_path = args.output_path if args.output_path else file_base
+ f_result_path = os.path.join(json_output_path, f"{model}-thresholds.json")
+ with open(f_result_path, 'w') as fp:
+ json.dump(result_dict, fp)
\ No newline at end of file
diff --git a/tests/resources/sample_layer_th.json b/tests/resources/sample_layer_th.json
new file mode 100644
index 00000000..30777a55
--- /dev/null
+++ b/tests/resources/sample_layer_th.json
@@ -0,0 +1,1461 @@
+{
+ "model_id": "ibm-granite--granite-3.2-8b-instruct",
+ "abs_diff": {
+ "model.base_model.embedding": 1.7258892434335918e-7,
+ "model.base_model.layers0.ln": 0.4083323414747196,
+ "model.base_model.layers0.attn.in_proj.query": 0.7099368339133884,
+ "model.base_model.layers0.attn.in_proj.key": 0.40915828503373886,
+ "model.base_model.layers0.attn.in_proj.value": 0.12381335209555287,
+ "model.base_model.layers0.attn.in_proj": 0.12381335209555287,
+ "model.base_model.layers0.attn.dense": 0.20848351243216826,
+ "model.base_model.layers0.ff_ln": 0.341461752477386,
+ "model.base_model.layers0.ff_sub_layer.wg": 0.567641596277766,
+ "model.base_model.layers0.ff_sub_layer.a": 0.22434598762517147,
+ "model.base_model.layers0.ff_sub_layer.w1": 0.6067234563084546,
+ "model.base_model.layers0.ff_sub_layer.w2": 0.26350909056410365,
+ "model.base_model.layers0.ff_sub_layer": 0.26350909056410365,
+ "model.base_model.layers1.ln": 0.8336723703984307,
+ "model.base_model.layers1.attn.in_proj.query": 0.8916629422172441,
+ "model.base_model.layers1.attn.in_proj.key": 0.9832392219817886,
+ "model.base_model.layers1.attn.in_proj.value": 0.19266429901472298,
+ "model.base_model.layers1.attn.in_proj": 0.19266429901472298,
+ "model.base_model.layers1.attn.dense": 0.22844734196163022,
+ "model.base_model.layers1.ff_ln": 0.5521520840910448,
+ "model.base_model.layers1.ff_sub_layer.wg": 0.7689933213289677,
+ "model.base_model.layers1.ff_sub_layer.a": 0.21604948169231603,
+ "model.base_model.layers1.ff_sub_layer.w1": 0.6776186104315166,
+ "model.base_model.layers1.ff_sub_layer.w2": 0.28674962163968093,
+ "model.base_model.layers1.ff_sub_layer": 0.28674962163968093,
+ "model.base_model.layers2.ln": 0.8847133906331596,
+ "model.base_model.layers2.attn.in_proj.query": 0.9420112644015102,
+ "model.base_model.layers2.attn.in_proj.key": 0.97354122886582,
+ "model.base_model.layers2.attn.in_proj.value": 0.22630204790296457,
+ "model.base_model.layers2.attn.in_proj": 0.22630204790296457,
+ "model.base_model.layers2.attn.dense": 0.18332329740252973,
+ "model.base_model.layers2.ff_ln": 0.525927330521208,
+ "model.base_model.layers2.ff_sub_layer.wg": 0.7853817552255021,
+ "model.base_model.layers2.ff_sub_layer.a": 0.2126492684120587,
+ "model.base_model.layers2.ff_sub_layer.w1": 0.7100478846115829,
+ "model.base_model.layers2.ff_sub_layer.w2": 0.47577485764898225,
+ "model.base_model.layers2.ff_sub_layer": 0.47577485764898225,
+ "model.base_model.layers3.ln": 1.187278312784196,
+ "model.base_model.layers3.attn.in_proj.query": 1.1850289651507167,
+ "model.base_model.layers3.attn.in_proj.key": 0.7688321305049673,
+ "model.base_model.layers3.attn.in_proj.value": 0.3176595376400436,
+ "model.base_model.layers3.attn.in_proj": 0.3176595376400436,
+ "model.base_model.layers3.attn.dense": 0.13923853417583815,
+ "model.base_model.layers3.ff_ln": 0.5475849962664296,
+ "model.base_model.layers3.ff_sub_layer.wg": 0.807950986423741,
+ "model.base_model.layers3.ff_sub_layer.a": 0.23481427667647428,
+ "model.base_model.layers3.ff_sub_layer.w1": 0.7336249578728383,
+ "model.base_model.layers3.ff_sub_layer.w2": 0.24303756003178248,
+ "model.base_model.layers3.ff_sub_layer": 0.24303756003178248,
+ "model.base_model.layers4.ln": 1.2737879599094148,
+ "model.base_model.layers4.attn.in_proj.query": 1.2522564221698715,
+ "model.base_model.layers4.attn.in_proj.key": 0.782955991324812,
+ "model.base_model.layers4.attn.in_proj.value": 0.3420645136385523,
+ "model.base_model.layers4.attn.in_proj": 0.3420645136385523,
+ "model.base_model.layers4.attn.dense": 0.1515373454933509,
+ "model.base_model.layers4.ff_ln": 0.569624540933297,
+ "model.base_model.layers4.ff_sub_layer.wg": 0.8330374753554995,
+ "model.base_model.layers4.ff_sub_layer.a": 0.2518642658931705,
+ "model.base_model.layers4.ff_sub_layer.w1": 0.7665822056062829,
+ "model.base_model.layers4.ff_sub_layer.w2": 0.49593585284155767,
+ "model.base_model.layers4.ff_sub_layer": 0.49593585284155767,
+ "model.base_model.layers5.ln": 1.5829002271302295,
+ "model.base_model.layers5.attn.in_proj.query": 1.38363090033142,
+ "model.base_model.layers5.attn.in_proj.key": 0.8454332397205699,
+ "model.base_model.layers5.attn.in_proj.value": 0.3697437891574192,
+ "model.base_model.layers5.attn.in_proj": 0.3697437891574192,
+ "model.base_model.layers5.attn.dense": 0.15887531995943088,
+ "model.base_model.layers5.ff_ln": 0.5912267834568247,
+ "model.base_model.layers5.ff_sub_layer.wg": 0.8652184969534987,
+ "model.base_model.layers5.ff_sub_layer.a": 0.2788222446063635,
+ "model.base_model.layers5.ff_sub_layer.w1": 0.805531576372046,
+ "model.base_model.layers5.ff_sub_layer.w2": 6.458161847668088,
+ "model.base_model.layers5.ff_sub_layer": 6.458161847668088,
+ "model.base_model.layers6.ln": 1.6201492274756133,
+ "model.base_model.layers6.attn.in_proj.query": 1.4744989120293985,
+ "model.base_model.layers6.attn.in_proj.key": 0.846251372966467,
+ "model.base_model.layers6.attn.in_proj.value": 0.5015866949449543,
+ "model.base_model.layers6.attn.in_proj": 0.5015866949449543,
+ "model.base_model.layers6.attn.dense": 0.16340297135545534,
+ "model.base_model.layers6.ff_ln": 0.6227574873739156,
+ "model.base_model.layers6.ff_sub_layer.wg": 0.9571147986366199,
+ "model.base_model.layers6.ff_sub_layer.a": 0.26415395545867043,
+ "model.base_model.layers6.ff_sub_layer.w1": 0.8342855779877698,
+ "model.base_model.layers6.ff_sub_layer.w2": 0.2796118565723263,
+ "model.base_model.layers6.ff_sub_layer": 0.2796118565723263,
+ "model.base_model.layers7.ln": 1.6159548963285146,
+ "model.base_model.layers7.attn.in_proj.query": 1.5228702271805528,
+ "model.base_model.layers7.attn.in_proj.key": 0.8960220712505957,
+ "model.base_model.layers7.attn.in_proj.value": 0.431180957807636,
+ "model.base_model.layers7.attn.in_proj": 0.431180957807636,
+ "model.base_model.layers7.attn.dense": 0.20139136499675042,
+ "model.base_model.layers7.ff_ln": 0.6423049997959157,
+ "model.base_model.layers7.ff_sub_layer.wg": 0.9256279202554698,
+ "model.base_model.layers7.ff_sub_layer.a": 0.2759643426245992,
+ "model.base_model.layers7.ff_sub_layer.w1": 0.875136757823107,
+ "model.base_model.layers7.ff_sub_layer.w2": 0.2839478439280297,
+ "model.base_model.layers7.ff_sub_layer": 0.2839478439280297,
+ "model.base_model.layers8.ln": 1.741672105462342,
+ "model.base_model.layers8.attn.in_proj.query": 1.5144931111825843,
+ "model.base_model.layers8.attn.in_proj.key": 0.9050611304626728,
+ "model.base_model.layers8.attn.in_proj.value": 0.45288345746031744,
+ "model.base_model.layers8.attn.in_proj": 0.45288345746031744,
+ "model.base_model.layers8.attn.dense": 0.2193785041754783,
+ "model.base_model.layers8.ff_ln": 0.6534704470539051,
+ "model.base_model.layers8.ff_sub_layer.wg": 0.8907470918723855,
+ "model.base_model.layers8.ff_sub_layer.a": 0.3020812444606932,
+ "model.base_model.layers8.ff_sub_layer.w1": 0.8787262506358846,
+ "model.base_model.layers8.ff_sub_layer.w2": 0.29793266450695177,
+ "model.base_model.layers8.ff_sub_layer": 0.29793266450695177,
+ "model.base_model.layers9.ln": 1.8567900810038824,
+ "model.base_model.layers9.attn.in_proj.query": 1.7709552081035318,
+ "model.base_model.layers9.attn.in_proj.key": 1.1152833545252818,
+ "model.base_model.layers9.attn.in_proj.value": 0.43904002787117086,
+ "model.base_model.layers9.attn.in_proj": 0.43904002787117086,
+ "model.base_model.layers9.attn.dense": 0.21161597281673783,
+ "model.base_model.layers9.ff_ln": 0.6781821123457642,
+ "model.base_model.layers9.ff_sub_layer.wg": 0.8942743139704491,
+ "model.base_model.layers9.ff_sub_layer.a": 0.32786965786847,
+ "model.base_model.layers9.ff_sub_layer.w1": 0.9299769907211495,
+ "model.base_model.layers9.ff_sub_layer.w2": 0.3286158026100792,
+ "model.base_model.layers9.ff_sub_layer": 0.3286158026100792,
+ "model.base_model.layers10.ln": 1.8202074364076473,
+ "model.base_model.layers10.attn.in_proj.query": 1.7732267575152894,
+ "model.base_model.layers10.attn.in_proj.key": 1.103206300113814,
+ "model.base_model.layers10.attn.in_proj.value": 0.43746372895795455,
+ "model.base_model.layers10.attn.in_proj": 0.43746372895795455,
+ "model.base_model.layers10.attn.dense": 0.2543700716338981,
+ "model.base_model.layers10.ff_ln": 0.6453063609710944,
+ "model.base_model.layers10.ff_sub_layer.wg": 0.868881275396978,
+ "model.base_model.layers10.ff_sub_layer.a": 0.38873500469467664,
+ "model.base_model.layers10.ff_sub_layer.w1": 0.9129176944690793,
+ "model.base_model.layers10.ff_sub_layer.w2": 0.364084027180546,
+ "model.base_model.layers10.ff_sub_layer": 0.364084027180546,
+ "model.base_model.layers11.ln": 2.038882746034096,
+ "model.base_model.layers11.attn.in_proj.query": 1.8738892559126075,
+ "model.base_model.layers11.attn.in_proj.key": 1.1812775777597084,
+ "model.base_model.layers11.attn.in_proj.value": 0.5292393958566379,
+ "model.base_model.layers11.attn.in_proj": 0.5292393958566379,
+ "model.base_model.layers11.attn.dense": 0.25136630607654065,
+ "model.base_model.layers11.ff_ln": 0.6411682302448929,
+ "model.base_model.layers11.ff_sub_layer.wg": 0.8503220759746339,
+ "model.base_model.layers11.ff_sub_layer.a": 0.41617093713712366,
+ "model.base_model.layers11.ff_sub_layer.w1": 0.9179281673445617,
+ "model.base_model.layers11.ff_sub_layer.w2": 0.37414450486896306,
+ "model.base_model.layers11.ff_sub_layer": 0.37414450486896306,
+ "model.base_model.layers12.ln": 2.4163706102336215,
+ "model.base_model.layers12.attn.in_proj.query": 2.0500510589305545,
+ "model.base_model.layers12.attn.in_proj.key": 1.30855177359395,
+ "model.base_model.layers12.attn.in_proj.value": 0.5560739862554699,
+ "model.base_model.layers12.attn.in_proj": 0.5560739862554699,
+ "model.base_model.layers12.attn.dense": 0.23275576971967807,
+ "model.base_model.layers12.ff_ln": 0.6883567718975563,
+ "model.base_model.layers12.ff_sub_layer.wg": 0.9019772927528285,
+ "model.base_model.layers12.ff_sub_layer.a": 0.45307173712771465,
+ "model.base_model.layers12.ff_sub_layer.w1": 0.9869529936984363,
+ "model.base_model.layers12.ff_sub_layer.w2": 0.4116728794564218,
+ "model.base_model.layers12.ff_sub_layer": 0.4116728794564218,
+ "model.base_model.layers13.ln": 2.2532344689730888,
+ "model.base_model.layers13.attn.in_proj.query": 1.986386265072094,
+ "model.base_model.layers13.attn.in_proj.key": 1.2302656755221903,
+ "model.base_model.layers13.attn.in_proj.value": 0.6760934663246982,
+ "model.base_model.layers13.attn.in_proj": 0.6760934663246982,
+ "model.base_model.layers13.attn.dense": 0.29404449729484533,
+ "model.base_model.layers13.ff_ln": 0.7130878630526057,
+ "model.base_model.layers13.ff_sub_layer.wg": 0.937333578785468,
+ "model.base_model.layers13.ff_sub_layer.a": 0.48989951904086926,
+ "model.base_model.layers13.ff_sub_layer.w1": 1.037001866367909,
+ "model.base_model.layers13.ff_sub_layer.w2": 0.45412857056087247,
+ "model.base_model.layers13.ff_sub_layer": 0.45412857056087247,
+ "model.base_model.layers14.ln": 2.437243134374781,
+ "model.base_model.layers14.attn.in_proj.query": 2.0743171823413564,
+ "model.base_model.layers14.attn.in_proj.key": 1.2871035522680647,
+ "model.base_model.layers14.attn.in_proj.value": 0.8113716584732076,
+ "model.base_model.layers14.attn.in_proj": 0.8113716584732076,
+ "model.base_model.layers14.attn.dense": 0.3208586103408894,
+ "model.base_model.layers14.ff_ln": 0.74594178340782,
+ "model.base_model.layers14.ff_sub_layer.wg": 0.9778100859617569,
+ "model.base_model.layers14.ff_sub_layer.a": 0.5209799969126123,
+ "model.base_model.layers14.ff_sub_layer.w1": 1.0920826711945681,
+ "model.base_model.layers14.ff_sub_layer.w2": 0.4954137823241344,
+ "model.base_model.layers14.ff_sub_layer": 0.4954137823241344,
+ "model.base_model.layers15.ln": 2.4062954038161326,
+ "model.base_model.layers15.attn.in_proj.query": 2.1558424705963386,
+ "model.base_model.layers15.attn.in_proj.key": 1.330842572889033,
+ "model.base_model.layers15.attn.in_proj.value": 0.7648466357382719,
+ "model.base_model.layers15.attn.in_proj": 0.7648466357382719,
+ "model.base_model.layers15.attn.dense": 0.3685467097914105,
+ "model.base_model.layers15.ff_ln": 0.7865979742947545,
+ "model.base_model.layers15.ff_sub_layer.wg": 1.0241376498806298,
+ "model.base_model.layers15.ff_sub_layer.a": 0.5586084616481519,
+ "model.base_model.layers15.ff_sub_layer.w1": 1.1498353263767294,
+ "model.base_model.layers15.ff_sub_layer.w2": 0.5425247171814839,
+ "model.base_model.layers15.ff_sub_layer": 0.5425247171814839,
+ "model.base_model.layers16.ln": 2.4443353059116983,
+ "model.base_model.layers16.attn.in_proj.query": 2.269587235946433,
+ "model.base_model.layers16.attn.in_proj.key": 1.4537761152294166,
+ "model.base_model.layers16.attn.in_proj.value": 0.7468716348790646,
+ "model.base_model.layers16.attn.in_proj": 0.7468716348790646,
+ "model.base_model.layers16.attn.dense": 0.4155567536426726,
+ "model.base_model.layers16.ff_ln": 0.8308677187547001,
+ "model.base_model.layers16.ff_sub_layer.wg": 1.0735924517320787,
+ "model.base_model.layers16.ff_sub_layer.a": 0.5551793673332353,
+ "model.base_model.layers16.ff_sub_layer.w1": 1.213821160353066,
+ "model.base_model.layers16.ff_sub_layer.w2": 0.5594797384921322,
+ "model.base_model.layers16.ff_sub_layer": 0.5594797384921322,
+ "model.base_model.layers17.ln": 2.6512140391853403,
+ "model.base_model.layers17.attn.in_proj.query": 2.3487128081705113,
+ "model.base_model.layers17.attn.in_proj.key": 1.4695542584316446,
+ "model.base_model.layers17.attn.in_proj.value": 0.7743666293930211,
+ "model.base_model.layers17.attn.in_proj": 0.7743666293930211,
+ "model.base_model.layers17.attn.dense": 0.33041246837038185,
+ "model.base_model.layers17.ff_ln": 0.895327132194183,
+ "model.base_model.layers17.ff_sub_layer.wg": 1.1225943284319686,
+ "model.base_model.layers17.ff_sub_layer.a": 0.5179989262569722,
+ "model.base_model.layers17.ff_sub_layer.w1": 1.2977184215107684,
+ "model.base_model.layers17.ff_sub_layer.w2": 0.533179908783865,
+ "model.base_model.layers17.ff_sub_layer": 0.533179908783865,
+ "model.base_model.layers18.ln": 2.4459591363505186,
+ "model.base_model.layers18.attn.in_proj.query": 2.2373733723308997,
+ "model.base_model.layers18.attn.in_proj.key": 1.400283822390989,
+ "model.base_model.layers18.attn.in_proj.value": 0.7509184895587986,
+ "model.base_model.layers18.attn.in_proj": 0.7509184895587986,
+ "model.base_model.layers18.attn.dense": 0.32375604539806324,
+ "model.base_model.layers18.ff_ln": 0.9078149707342478,
+ "model.base_model.layers18.ff_sub_layer.wg": 1.131303164011437,
+ "model.base_model.layers18.ff_sub_layer.a": 0.5232468505499002,
+ "model.base_model.layers18.ff_sub_layer.w1": 1.2913645074679432,
+ "model.base_model.layers18.ff_sub_layer.w2": 0.5312936297121092,
+ "model.base_model.layers18.ff_sub_layer": 0.5312936297121092,
+ "model.base_model.layers19.ln": 2.676787761099476,
+ "model.base_model.layers19.attn.in_proj.query": 2.3393956932132536,
+ "model.base_model.layers19.attn.in_proj.key": 1.4064832828052416,
+ "model.base_model.layers19.attn.in_proj.value": 0.7740809612930063,
+ "model.base_model.layers19.attn.in_proj": 0.7740809612930063,
+ "model.base_model.layers19.attn.dense": 0.3794086774054501,
+ "model.base_model.layers19.ff_ln": 0.9050935588868597,
+ "model.base_model.layers19.ff_sub_layer.wg": 1.132639379322763,
+ "model.base_model.layers19.ff_sub_layer.a": 0.5148442761709476,
+ "model.base_model.layers19.ff_sub_layer.w1": 1.2685600336469933,
+ "model.base_model.layers19.ff_sub_layer.w2": 0.5907017421871078,
+ "model.base_model.layers19.ff_sub_layer": 0.5907017421871078,
+ "model.base_model.layers20.ln": 0.05869723235891928,
+ "model.base_model.layers20.attn.in_proj.query": 5.875081363292194e-12,
+ "model.base_model.layers20.attn.in_proj.key": 3.1874520397371848e-12,
+ "model.base_model.layers20.attn.in_proj.value": 0.019149133693904922,
+ "model.base_model.layers20.attn.in_proj": 0.019149133693904922,
+ "model.base_model.layers20.attn.dense": 0.01650506472777983,
+ "model.base_model.layers20.ff_ln": 0.9039704212821755,
+ "model.base_model.layers20.ff_sub_layer.wg": 1.1358363761297412,
+ "model.base_model.layers20.ff_sub_layer.a": 0.49959621499860896,
+ "model.base_model.layers20.ff_sub_layer.w1": 1.2331316666130403,
+ "model.base_model.layers20.ff_sub_layer.w2": 0.5059322881960626,
+ "model.base_model.layers20.ff_sub_layer": 0.5059322881960626,
+ "model.base_model.layers21.ln": 2.683327342257717,
+ "model.base_model.layers21.attn.in_proj.query": 2.289205190674962,
+ "model.base_model.layers21.attn.in_proj.key": 1.4363975749387148,
+ "model.base_model.layers21.attn.in_proj.value": 0.7381585315543487,
+ "model.base_model.layers21.attn.in_proj": 0.7381585315543487,
+ "model.base_model.layers21.attn.dense": 0.35251317378922625,
+ "model.base_model.layers21.ff_ln": 0.9113812385779858,
+ "model.base_model.layers21.ff_sub_layer.wg": 1.1491777317149507,
+ "model.base_model.layers21.ff_sub_layer.a": 0.49914745963202006,
+ "model.base_model.layers21.ff_sub_layer.w1": 1.237055949078281,
+ "model.base_model.layers21.ff_sub_layer.w2": 0.5369818215476636,
+ "model.base_model.layers21.ff_sub_layer": 0.5369818215476636,
+ "model.base_model.layers22.ln": 2.6452733719167996,
+ "model.base_model.layers22.attn.in_proj.query": 2.106319142532332,
+ "model.base_model.layers22.attn.in_proj.key": 1.2702739611235492,
+ "model.base_model.layers22.attn.in_proj.value": 0.7650688708442046,
+ "model.base_model.layers22.attn.in_proj": 0.7650688708442046,
+ "model.base_model.layers22.attn.dense": 0.22197807736713943,
+ "model.base_model.layers22.ff_ln": 0.9170507180633947,
+ "model.base_model.layers22.ff_sub_layer.wg": 1.1649007558621947,
+ "model.base_model.layers22.ff_sub_layer.a": 0.5025733044419384,
+ "model.base_model.layers22.ff_sub_layer.w1": 1.2365195453856566,
+ "model.base_model.layers22.ff_sub_layer.w2": 0.5207189861586834,
+ "model.base_model.layers22.ff_sub_layer": 0.5207189861586834,
+ "model.base_model.layers23.ln": 0.04076474676222718,
+ "model.base_model.layers23.attn.in_proj.query": 4.3584737606226e-12,
+ "model.base_model.layers23.attn.in_proj.key": 2.2713250241549265e-12,
+ "model.base_model.layers23.attn.in_proj.value": 0.011164386271087137,
+ "model.base_model.layers23.attn.in_proj": 0.011164386271087137,
+ "model.base_model.layers23.attn.dense": 0.008722732365763724,
+ "model.base_model.layers23.ff_ln": 0.9530040511329709,
+ "model.base_model.layers23.ff_sub_layer.wg": 1.2122010602800901,
+ "model.base_model.layers23.ff_sub_layer.a": 0.5210196550171168,
+ "model.base_model.layers23.ff_sub_layer.w1": 1.2721844742915902,
+ "model.base_model.layers23.ff_sub_layer.w2": 0.5207852088064424,
+ "model.base_model.layers23.ff_sub_layer": 0.5207852088064424,
+ "model.base_model.layers24.ln": 0.04148058891449975,
+ "model.base_model.layers24.attn.in_proj.query": 4.350413377985654e-12,
+ "model.base_model.layers24.attn.in_proj.key": 2.310924744916744e-12,
+ "model.base_model.layers24.attn.in_proj.value": 0.011669863540726839,
+ "model.base_model.layers24.attn.in_proj": 0.011669863540726839,
+ "model.base_model.layers24.attn.dense": 0.009749030890235091,
+ "model.base_model.layers24.ff_ln": 0.9209732334431034,
+ "model.base_model.layers24.ff_sub_layer.wg": 1.1783206389277576,
+ "model.base_model.layers24.ff_sub_layer.a": 0.4974131186041296,
+ "model.base_model.layers24.ff_sub_layer.w1": 1.2282531957175757,
+ "model.base_model.layers24.ff_sub_layer.w2": 0.4877984638344186,
+ "model.base_model.layers24.ff_sub_layer": 0.4877984638344186,
+ "model.base_model.layers25.ln": 2.6914940671608956,
+ "model.base_model.layers25.attn.in_proj.query": 2.079996666484281,
+ "model.base_model.layers25.attn.in_proj.key": 1.2256532914682756,
+ "model.base_model.layers25.attn.in_proj.value": 0.8446561344670284,
+ "model.base_model.layers25.attn.in_proj": 0.8446561344670284,
+ "model.base_model.layers25.attn.dense": 0.23142293885894077,
+ "model.base_model.layers25.ff_ln": 0.9550253005897409,
+ "model.base_model.layers25.ff_sub_layer.wg": 1.2256491705546648,
+ "model.base_model.layers25.ff_sub_layer.a": 0.5235781749861929,
+ "model.base_model.layers25.ff_sub_layer.w1": 1.2707070667436549,
+ "model.base_model.layers25.ff_sub_layer.w2": 0.5201997339672954,
+ "model.base_model.layers25.ff_sub_layer": 0.5201997339672954,
+ "model.base_model.layers26.ln": 0.04852477119171675,
+ "model.base_model.layers26.attn.in_proj.query": 5.2433691181258795e-12,
+ "model.base_model.layers26.attn.in_proj.key": 2.6993943331679985e-12,
+ "model.base_model.layers26.attn.in_proj.value": 0.013973908680838854,
+ "model.base_model.layers26.attn.in_proj": 0.013973908680838854,
+ "model.base_model.layers26.attn.dense": 0.011292771432111815,
+ "model.base_model.layers26.ff_ln": 0.9791606070877357,
+ "model.base_model.layers26.ff_sub_layer.wg": 1.2557571186952852,
+ "model.base_model.layers26.ff_sub_layer.a": 0.5510265953756356,
+ "model.base_model.layers26.ff_sub_layer.w1": 1.299429319889889,
+ "model.base_model.layers26.ff_sub_layer.w2": 0.5547229432942743,
+ "model.base_model.layers26.ff_sub_layer": 0.5547229432942743,
+ "model.base_model.layers27.ln": 0.05761726481170151,
+ "model.base_model.layers27.attn.in_proj.query": 6.2990615182407684e-12,
+ "model.base_model.layers27.attn.in_proj.key": 3.563824771312352e-12,
+ "model.base_model.layers27.attn.in_proj.value": 0.01689026442218174,
+ "model.base_model.layers27.attn.in_proj": 0.01689026442218174,
+ "model.base_model.layers27.attn.dense": 0.014086568839176132,
+ "model.base_model.layers27.ff_ln": 1.0254311249358894,
+ "model.base_model.layers27.ff_sub_layer.wg": 1.3301749696605942,
+ "model.base_model.layers27.ff_sub_layer.a": 0.57086262231711,
+ "model.base_model.layers27.ff_sub_layer.w1": 1.3588965368445045,
+ "model.base_model.layers27.ff_sub_layer.w2": 0.5754856665910341,
+ "model.base_model.layers27.ff_sub_layer": 0.5754856665910341,
+ "model.base_model.layers28.ln": 2.815465046774573,
+ "model.base_model.layers28.attn.in_proj.query": 2.089838276603043,
+ "model.base_model.layers28.attn.in_proj.key": 1.2063263141816927,
+ "model.base_model.layers28.attn.in_proj.value": 0.9208380969372306,
+ "model.base_model.layers28.attn.in_proj": 0.9208380969372306,
+ "model.base_model.layers28.attn.dense": 0.3453570196775779,
+ "model.base_model.layers28.ff_ln": 1.0665693068297435,
+ "model.base_model.layers28.ff_sub_layer.wg": 1.3650839635469527,
+ "model.base_model.layers28.ff_sub_layer.a": 0.6022059546246795,
+ "model.base_model.layers28.ff_sub_layer.w1": 1.4164889273401393,
+ "model.base_model.layers28.ff_sub_layer.w2": 0.6325165800911244,
+ "model.base_model.layers28.ff_sub_layer": 0.6325165800911244,
+ "model.base_model.layers29.ln": 0.10897941348773572,
+ "model.base_model.layers29.attn.in_proj.query": 1.3703584847019003e-11,
+ "model.base_model.layers29.attn.in_proj.key": 1.0041752848134295e-11,
+ "model.base_model.layers29.attn.in_proj.value": 0.03577135277541929,
+ "model.base_model.layers29.attn.in_proj": 0.03577135277541929,
+ "model.base_model.layers29.attn.dense": 0.03097476926018952,
+ "model.base_model.layers29.ff_ln": 1.1029617448772335,
+ "model.base_model.layers29.ff_sub_layer.wg": 1.4101410421393201,
+ "model.base_model.layers29.ff_sub_layer.a": 0.6352038645951448,
+ "model.base_model.layers29.ff_sub_layer.w1": 1.471897271218638,
+ "model.base_model.layers29.ff_sub_layer.w2": 0.706091258163473,
+ "model.base_model.layers29.ff_sub_layer": 0.706091258163473,
+ "model.base_model.layers30.ln": 3.0577601477991747,
+ "model.base_model.layers30.attn.in_proj.query": 2.2129892989404607,
+ "model.base_model.layers30.attn.in_proj.key": 1.2527367462035874,
+ "model.base_model.layers30.attn.in_proj.value": 1.0227985561921364,
+ "model.base_model.layers30.attn.in_proj": 1.0227985561921364,
+ "model.base_model.layers30.attn.dense": 0.4485761939015526,
+ "model.base_model.layers30.ff_ln": 1.1625355613262403,
+ "model.base_model.layers30.ff_sub_layer.wg": 1.477637721431356,
+ "model.base_model.layers30.ff_sub_layer.a": 0.6841802409270257,
+ "model.base_model.layers30.ff_sub_layer.w1": 1.5593898954122012,
+ "model.base_model.layers30.ff_sub_layer.w2": 0.7938874280695967,
+ "model.base_model.layers30.ff_sub_layer": 0.7938874280695967,
+ "model.base_model.layers31.ln": 3.130117016136881,
+ "model.base_model.layers31.attn.in_proj.query": 2.397979490267338,
+ "model.base_model.layers31.attn.in_proj.key": 1.4083771778423888,
+ "model.base_model.layers31.attn.in_proj.value": 1.0687232303788325,
+ "model.base_model.layers31.attn.in_proj": 1.0687232303788325,
+ "model.base_model.layers31.attn.dense": 0.5703270227263976,
+ "model.base_model.layers31.ff_ln": 1.2351666857681705,
+ "model.base_model.layers31.ff_sub_layer.wg": 1.5622753438177648,
+ "model.base_model.layers31.ff_sub_layer.a": 0.7138770555459574,
+ "model.base_model.layers31.ff_sub_layer.w1": 1.6589893255729982,
+ "model.base_model.layers31.ff_sub_layer.w2": 0.8682725145696991,
+ "model.base_model.layers31.ff_sub_layer": 0.8682725145696991,
+ "model.base_model.layers32.ln": 3.510177576653432,
+ "model.base_model.layers32.attn.in_proj.query": 2.6003867116961072,
+ "model.base_model.layers32.attn.in_proj.key": 1.4354428524553307,
+ "model.base_model.layers32.attn.in_proj.value": 1.2530256301596587,
+ "model.base_model.layers32.attn.in_proj": 1.2530256301596587,
+ "model.base_model.layers32.attn.dense": 0.5613037959531225,
+ "model.base_model.layers32.ff_ln": 1.3309315632753136,
+ "model.base_model.layers32.ff_sub_layer.wg": 1.684296329380358,
+ "model.base_model.layers32.ff_sub_layer.a": 0.807838512093823,
+ "model.base_model.layers32.ff_sub_layer.w1": 1.813825138245955,
+ "model.base_model.layers32.ff_sub_layer.w2": 1.1140093855705708,
+ "model.base_model.layers32.ff_sub_layer": 1.1140093855705708,
+ "model.base_model.layers33.ln": 3.3994719056638307,
+ "model.base_model.layers33.attn.in_proj.query": 2.459159138834412,
+ "model.base_model.layers33.attn.in_proj.key": 1.424937517391673,
+ "model.base_model.layers33.attn.in_proj.value": 1.1784484007100327,
+ "model.base_model.layers33.attn.in_proj": 1.1784484007100327,
+ "model.base_model.layers33.attn.dense": 0.7541024870140104,
+ "model.base_model.layers33.ff_ln": 1.4271515808182542,
+ "model.base_model.layers33.ff_sub_layer.wg": 1.7950767244829007,
+ "model.base_model.layers33.ff_sub_layer.a": 0.9005689213874422,
+ "model.base_model.layers33.ff_sub_layer.w1": 1.9605562935520227,
+ "model.base_model.layers33.ff_sub_layer.w2": 1.371932753468498,
+ "model.base_model.layers33.ff_sub_layer": 1.371932753468498,
+ "model.base_model.layers34.ln": 3.6159380021416987,
+ "model.base_model.layers34.attn.in_proj.query": 2.66264911046287,
+ "model.base_model.layers34.attn.in_proj.key": 1.4587745530262968,
+ "model.base_model.layers34.attn.in_proj.value": 1.4644100078343885,
+ "model.base_model.layers34.attn.in_proj": 1.4644100078343885,
+ "model.base_model.layers34.attn.dense": 1.0422025219468418,
+ "model.base_model.layers34.ff_ln": 1.6099415828846035,
+ "model.base_model.layers34.ff_sub_layer.wg": 1.9851124323663396,
+ "model.base_model.layers34.ff_sub_layer.a": 0.9364633326762221,
+ "model.base_model.layers34.ff_sub_layer.w1": 2.225696279107312,
+ "model.base_model.layers34.ff_sub_layer.w2": 1.6401083323900605,
+ "model.base_model.layers34.ff_sub_layer": 1.6401083323900605,
+ "model.base_model.layers35.ln": 3.711278463660952,
+ "model.base_model.layers35.attn.in_proj.query": 2.561732873746854,
+ "model.base_model.layers35.attn.in_proj.key": 1.294329713724366,
+ "model.base_model.layers35.attn.in_proj.value": 1.5323461483923608,
+ "model.base_model.layers35.attn.in_proj": 1.5323461483923608,
+ "model.base_model.layers35.attn.dense": 1.0954627428383212,
+ "model.base_model.layers35.ff_ln": 1.775861826807536,
+ "model.base_model.layers35.ff_sub_layer.wg": 2.161838614440668,
+ "model.base_model.layers35.ff_sub_layer.a": 1.0088421067208395,
+ "model.base_model.layers35.ff_sub_layer.w1": 2.4813606416689082,
+ "model.base_model.layers35.ff_sub_layer.w2": 2.010088221844646,
+ "model.base_model.layers35.ff_sub_layer": 2.010088221844646,
+ "model.base_model.layers36.ln": 3.8602793303041363,
+ "model.base_model.layers36.attn.in_proj.query": 2.6063705879063876,
+ "model.base_model.layers36.attn.in_proj.key": 1.2474004703198267,
+ "model.base_model.layers36.attn.in_proj.value": 1.9287609402199686,
+ "model.base_model.layers36.attn.in_proj": 1.9287609402199686,
+ "model.base_model.layers36.attn.dense": 1.1947622565945153,
+ "model.base_model.layers36.ff_ln": 2.079037610574989,
+ "model.base_model.layers36.ff_sub_layer.wg": 2.525487608494635,
+ "model.base_model.layers36.ff_sub_layer.a": 1.0332693238413477,
+ "model.base_model.layers36.ff_sub_layer.w1": 2.8818690519378407,
+ "model.base_model.layers36.ff_sub_layer.w2": 2.395313843284508,
+ "model.base_model.layers36.ff_sub_layer": 2.395313843284508,
+ "model.base_model.layers37.ln": 4.186380825760128,
+ "model.base_model.layers37.attn.in_proj.query": 2.672298526408581,
+ "model.base_model.layers37.attn.in_proj.key": 1.2635284782754816,
+ "model.base_model.layers37.attn.in_proj.value": 2.178426502393175,
+ "model.base_model.layers37.attn.in_proj": 2.178426502393175,
+ "model.base_model.layers37.attn.dense": 1.4376572090517186,
+ "model.base_model.layers37.ff_ln": 2.1801742391103995,
+ "model.base_model.layers37.ff_sub_layer.wg": 2.641452404906007,
+ "model.base_model.layers37.ff_sub_layer.a": 1.214881662670414,
+ "model.base_model.layers37.ff_sub_layer.w1": 3.0441249284997247,
+ "model.base_model.layers37.ff_sub_layer.w2": 3.3503320529626173,
+ "model.base_model.layers37.ff_sub_layer": 3.3503320529626173,
+ "model.base_model.layers38.ln": 4.456169508374579,
+ "model.base_model.layers38.attn.in_proj.query": 2.731806961679497,
+ "model.base_model.layers38.attn.in_proj.key": 1.1605432361845325,
+ "model.base_model.layers38.attn.in_proj.value": 2.955616851145189,
+ "model.base_model.layers38.attn.in_proj": 2.955616851145189,
+ "model.base_model.layers38.attn.dense": 1.8238709428240343,
+ "model.base_model.layers38.ff_ln": 2.355534708045058,
+ "model.base_model.layers38.ff_sub_layer.wg": 2.893516055983882,
+ "model.base_model.layers38.ff_sub_layer.a": 1.4775752726302787,
+ "model.base_model.layers38.ff_sub_layer.w1": 3.406498283269355,
+ "model.base_model.layers38.ff_sub_layer.w2": 4.872777683615914,
+ "model.base_model.layers38.ff_sub_layer": 4.872777683615914,
+ "model.base_model.layers39.ln": 5.245948807426597,
+ "model.base_model.layers39.attn.in_proj.query": 3.301107035939244,
+ "model.base_model.layers39.attn.in_proj.key": 1.3283636663444258,
+ "model.base_model.layers39.attn.in_proj.value": 3.009072478653194,
+ "model.base_model.layers39.attn.in_proj": 3.009072478653194,
+ "model.base_model.layers39.attn.dense": 3.0974715116743585,
+ "model.base_model.layers39.ff_ln": 2.832660363940657,
+ "model.base_model.layers39.ff_sub_layer.wg": 3.5396036346806463,
+ "model.base_model.layers39.ff_sub_layer.a": 2.1432838753474726,
+ "model.base_model.layers39.ff_sub_layer.w1": 4.608670447299947,
+ "model.base_model.layers39.ff_sub_layer.w2": 38.01352590965104,
+ "model.base_model.layers39.ff_sub_layer": 38.01352590965104,
+ "model.base_model.dec_norm": 17.89632157170833,
+ "model.base_model": 9.772593552981837,
+ "model.head": 138.84959684098462
+ },
+ "cos_sim_avg": {
+ "model.base_model.embedding": 1.0000095833092928,
+ "model.base_model.layers0.ln": 1.0000184457749128,
+ "model.base_model.layers0.attn.in_proj.query": 0.9999569654464722,
+ "model.base_model.layers0.attn.in_proj.key": 1.000030318275094,
+ "model.base_model.layers0.attn.in_proj.value": 0.9999886471778154,
+ "model.base_model.layers0.attn.in_proj": 0.9999886471778154,
+ "model.base_model.layers0.attn.dense": 1.0000049602240324,
+ "model.base_model.layers0.ff_ln": 0.9999961135908961,
+ "model.base_model.layers0.ff_sub_layer.wg": 1.0000046007335186,
+ "model.base_model.layers0.ff_sub_layer.a": 1.0000147661194205,
+ "model.base_model.layers0.ff_sub_layer.w1": 1.0000530388206244,
+ "model.base_model.layers0.ff_sub_layer.w2": 0.9999505197629333,
+ "model.base_model.layers0.ff_sub_layer": 0.9999505197629333,
+ "model.base_model.layers1.ln": 0.9999880297109485,
+ "model.base_model.layers1.attn.in_proj.query": 0.9999853884801269,
+ "model.base_model.layers1.attn.in_proj.key": 0.9999964172020555,
+ "model.base_model.layers1.attn.in_proj.value": 0.999986432492733,
+ "model.base_model.layers1.attn.in_proj": 0.999986432492733,
+ "model.base_model.layers1.attn.dense": 0.9999998668208718,
+ "model.base_model.layers1.ff_ln": 1.0000135749578476,
+ "model.base_model.layers1.ff_sub_layer.wg": 1.0000185761600733,
+ "model.base_model.layers1.ff_sub_layer.a": 0.999986195936799,
+ "model.base_model.layers1.ff_sub_layer.w1": 1.000002909451723,
+ "model.base_model.layers1.ff_sub_layer.w2": 1.0000054109841585,
+ "model.base_model.layers1.ff_sub_layer": 1.0000054109841585,
+ "model.base_model.layers2.ln": 1.0000562984496355,
+ "model.base_model.layers2.attn.in_proj.query": 0.9999984446913004,
+ "model.base_model.layers2.attn.in_proj.key": 1.0000433661043644,
+ "model.base_model.layers2.attn.in_proj.value": 0.9999881014227867,
+ "model.base_model.layers2.attn.in_proj": 0.9999881014227867,
+ "model.base_model.layers2.attn.dense": 1.000006778165698,
+ "model.base_model.layers2.ff_ln": 1.0000063749030232,
+ "model.base_model.layers2.ff_sub_layer.wg": 1.0000359807163477,
+ "model.base_model.layers2.ff_sub_layer.a": 0.999996374361217,
+ "model.base_model.layers2.ff_sub_layer.w1": 0.9999919822439551,
+ "model.base_model.layers2.ff_sub_layer.w2": 1.0000033108517528,
+ "model.base_model.layers2.ff_sub_layer": 1.0000033108517528,
+ "model.base_model.layers3.ln": 1.0000399872660637,
+ "model.base_model.layers3.attn.in_proj.query": 0.9999705422669649,
+ "model.base_model.layers3.attn.in_proj.key": 1.0000282796099782,
+ "model.base_model.layers3.attn.in_proj.value": 0.9999693483114243,
+ "model.base_model.layers3.attn.in_proj": 0.9999693483114243,
+ "model.base_model.layers3.attn.dense": 1.0000234134495258,
+ "model.base_model.layers3.ff_ln": 0.9999883593991399,
+ "model.base_model.layers3.ff_sub_layer.wg": 0.9999996954575181,
+ "model.base_model.layers3.ff_sub_layer.a": 1.0000347420573235,
+ "model.base_model.layers3.ff_sub_layer.w1": 1.0000090897083282,
+ "model.base_model.layers3.ff_sub_layer.w2": 0.99996093288064,
+ "model.base_model.layers3.ff_sub_layer": 0.99996093288064,
+ "model.base_model.layers4.ln": 1.0000464990735054,
+ "model.base_model.layers4.attn.in_proj.query": 1.000026903115213,
+ "model.base_model.layers4.attn.in_proj.key": 0.999997797422111,
+ "model.base_model.layers4.attn.in_proj.value": 0.9999980982393026,
+ "model.base_model.layers4.attn.in_proj": 0.9999980982393026,
+ "model.base_model.layers4.attn.dense": 1.0000115931034088,
+ "model.base_model.layers4.ff_ln": 1.0000024493783712,
+ "model.base_model.layers4.ff_sub_layer.wg": 1.0000334102660418,
+ "model.base_model.layers4.ff_sub_layer.a": 1.0000202497467399,
+ "model.base_model.layers4.ff_sub_layer.w1": 1.0000194078311324,
+ "model.base_model.layers4.ff_sub_layer.w2": 1.0000301413238049,
+ "model.base_model.layers4.ff_sub_layer": 1.0000301413238049,
+ "model.base_model.layers5.ln": 0.999975698068738,
+ "model.base_model.layers5.attn.in_proj.query": 1.0000565703958273,
+ "model.base_model.layers5.attn.in_proj.key": 1.0000175070017576,
+ "model.base_model.layers5.attn.in_proj.value": 1.0000161584466696,
+ "model.base_model.layers5.attn.in_proj": 1.0000161584466696,
+ "model.base_model.layers5.attn.dense": 1.000037212856114,
+ "model.base_model.layers5.ff_ln": 0.9999989373609424,
+ "model.base_model.layers5.ff_sub_layer.wg": 0.9999949997290969,
+ "model.base_model.layers5.ff_sub_layer.a": 1.000015702098608,
+ "model.base_model.layers5.ff_sub_layer.w1": 0.9999757846817374,
+ "model.base_model.layers5.ff_sub_layer.w2": 1.0000016959384084,
+ "model.base_model.layers5.ff_sub_layer": 1.0000016959384084,
+ "model.base_model.layers6.ln": 1.0000160429626703,
+ "model.base_model.layers6.attn.in_proj.query": 0.9999967850744724,
+ "model.base_model.layers6.attn.in_proj.key": 0.9999998025596142,
+ "model.base_model.layers6.attn.in_proj.value": 1.0000286241993308,
+ "model.base_model.layers6.attn.in_proj": 1.0000286241993308,
+ "model.base_model.layers6.attn.dense": 1.0000265408307314,
+ "model.base_model.layers6.ff_ln": 1.0000047851353884,
+ "model.base_model.layers6.ff_sub_layer.wg": 1.0000188006088138,
+ "model.base_model.layers6.ff_sub_layer.a": 0.9999990398064256,
+ "model.base_model.layers6.ff_sub_layer.w1": 0.9999480145052075,
+ "model.base_model.layers6.ff_sub_layer.w2": 1.0000073416158557,
+ "model.base_model.layers6.ff_sub_layer": 1.0000073416158557,
+ "model.base_model.layers7.ln": 1.0000189878046513,
+ "model.base_model.layers7.attn.in_proj.query": 1.0000445498153567,
+ "model.base_model.layers7.attn.in_proj.key": 0.9999774135649204,
+ "model.base_model.layers7.attn.in_proj.value": 1.0000242739915848,
+ "model.base_model.layers7.attn.in_proj": 1.0000242739915848,
+ "model.base_model.layers7.attn.dense": 0.9999957242980599,
+ "model.base_model.layers7.ff_ln": 0.999947358854115,
+ "model.base_model.layers7.ff_sub_layer.wg": 0.999967728741467,
+ "model.base_model.layers7.ff_sub_layer.a": 1.000053628347814,
+ "model.base_model.layers7.ff_sub_layer.w1": 1.0000687642022967,
+ "model.base_model.layers7.ff_sub_layer.w2": 0.9999854220077395,
+ "model.base_model.layers7.ff_sub_layer": 0.9999854220077395,
+ "model.base_model.layers8.ln": 1.000000380910933,
+ "model.base_model.layers8.attn.in_proj.query": 0.9999749204143882,
+ "model.base_model.layers8.attn.in_proj.key": 0.9999820357188582,
+ "model.base_model.layers8.attn.in_proj.value": 0.9999982211738825,
+ "model.base_model.layers8.attn.in_proj": 0.9999982211738825,
+ "model.base_model.layers8.attn.dense": 1.0000164853408933,
+ "model.base_model.layers8.ff_ln": 0.9999812999740243,
+ "model.base_model.layers8.ff_sub_layer.wg": 0.9999558068811893,
+ "model.base_model.layers8.ff_sub_layer.a": 0.9999397126957774,
+ "model.base_model.layers8.ff_sub_layer.w1": 0.9999668747186661,
+ "model.base_model.layers8.ff_sub_layer.w2": 0.9999804692342877,
+ "model.base_model.layers8.ff_sub_layer": 0.9999804692342877,
+ "model.base_model.layers9.ln": 1.0000508911907673,
+ "model.base_model.layers9.attn.in_proj.query": 0.9999695355072618,
+ "model.base_model.layers9.attn.in_proj.key": 0.9999892776831985,
+ "model.base_model.layers9.attn.in_proj.value": 0.9999962784349918,
+ "model.base_model.layers9.attn.in_proj": 0.9999962784349918,
+ "model.base_model.layers9.attn.dense": 1.000051881186664,
+ "model.base_model.layers9.ff_ln": 1.0000225473195314,
+ "model.base_model.layers9.ff_sub_layer.wg": 1.0000174287706614,
+ "model.base_model.layers9.ff_sub_layer.a": 0.9999986365437508,
+ "model.base_model.layers9.ff_sub_layer.w1": 1.0000402759760618,
+ "model.base_model.layers9.ff_sub_layer.w2": 0.9999891892075539,
+ "model.base_model.layers9.ff_sub_layer": 0.9999891892075539,
+ "model.base_model.layers10.ln": 0.9999757930636406,
+ "model.base_model.layers10.attn.in_proj.query": 1.0000325981527567,
+ "model.base_model.layers10.attn.in_proj.key": 0.9999954178929329,
+ "model.base_model.layers10.attn.in_proj.value": 1.000049689784646,
+ "model.base_model.layers10.attn.in_proj": 1.000049689784646,
+ "model.base_model.layers10.attn.dense": 0.9999648351222277,
+ "model.base_model.layers10.ff_ln": 1.0000237496569753,
+ "model.base_model.layers10.ff_sub_layer.wg": 0.999984978698194,
+ "model.base_model.layers10.ff_sub_layer.a": 1.0000277813524008,
+ "model.base_model.layers10.ff_sub_layer.w1": 0.9999703979119658,
+ "model.base_model.layers10.ff_sub_layer.w2": 1.000018141232431,
+ "model.base_model.layers10.ff_sub_layer": 1.000018141232431,
+ "model.base_model.layers11.ln": 0.9999837828800082,
+ "model.base_model.layers11.attn.in_proj.query": 0.9999337792396545,
+ "model.base_model.layers11.attn.in_proj.key": 1.0000134026631713,
+ "model.base_model.layers11.attn.in_proj.value": 0.9999888250604272,
+ "model.base_model.layers11.attn.in_proj": 0.9999888250604272,
+ "model.base_model.layers11.attn.dense": 0.9999930802732706,
+ "model.base_model.layers11.ff_ln": 1.0000165477395058,
+ "model.base_model.layers11.ff_sub_layer.wg": 0.999998620711267,
+ "model.base_model.layers11.ff_sub_layer.a": 1.000002202577889,
+ "model.base_model.layers11.ff_sub_layer.w1": 0.9999753125011921,
+ "model.base_model.layers11.ff_sub_layer.w2": 0.9999937638640404,
+ "model.base_model.layers11.ff_sub_layer": 0.9999937638640404,
+ "model.base_model.layers12.ln": 0.9999772552400827,
+ "model.base_model.layers12.attn.in_proj.query": 0.9999653361737728,
+ "model.base_model.layers12.attn.in_proj.key": 0.9999616499990225,
+ "model.base_model.layers12.attn.in_proj.value": 0.9999822629615664,
+ "model.base_model.layers12.attn.in_proj": 0.9999822629615664,
+ "model.base_model.layers12.attn.dense": 0.9999957624822855,
+ "model.base_model.layers12.ff_ln": 0.9999648667871952,
+ "model.base_model.layers12.ff_sub_layer.wg": 1.0000127339735627,
+ "model.base_model.layers12.ff_sub_layer.a": 0.9999869000166655,
+ "model.base_model.layers12.ff_sub_layer.w1": 1.0000022910535336,
+ "model.base_model.layers12.ff_sub_layer.w2": 1.0000239154323936,
+ "model.base_model.layers12.ff_sub_layer": 1.0000239154323936,
+ "model.base_model.layers13.ln": 0.9999950844794512,
+ "model.base_model.layers13.attn.in_proj.query": 0.999962380155921,
+ "model.base_model.layers13.attn.in_proj.key": 1.000015706755221,
+ "model.base_model.layers13.attn.in_proj.value": 1.000016774982214,
+ "model.base_model.layers13.attn.in_proj": 1.000016774982214,
+ "model.base_model.layers13.attn.dense": 0.9999829614534974,
+ "model.base_model.layers13.ff_ln": 0.999973920173943,
+ "model.base_model.layers13.ff_sub_layer.wg": 0.9999783085659146,
+ "model.base_model.layers13.ff_sub_layer.a": 1.0000212853774428,
+ "model.base_model.layers13.ff_sub_layer.w1": 1.0000321120023727,
+ "model.base_model.layers13.ff_sub_layer.w2": 0.9999998137354851,
+ "model.base_model.layers13.ff_sub_layer": 0.9999998137354851,
+ "model.base_model.layers14.ln": 1.000026528723538,
+ "model.base_model.layers14.attn.in_proj.query": 0.9999550497159362,
+ "model.base_model.layers14.attn.in_proj.key": 1.0000005215406418,
+ "model.base_model.layers14.attn.in_proj.value": 1.0000283233821392,
+ "model.base_model.layers14.attn.in_proj": 1.0000283233821392,
+ "model.base_model.layers14.attn.dense": 0.9999569887295365,
+ "model.base_model.layers14.ff_ln": 1.0000332901254296,
+ "model.base_model.layers14.ff_sub_layer.wg": 1.000024825334549,
+ "model.base_model.layers14.ff_sub_layer.a": 0.9999833712354302,
+ "model.base_model.layers14.ff_sub_layer.w1": 1.000031827017665,
+ "model.base_model.layers14.ff_sub_layer.w2": 0.9999874485656619,
+ "model.base_model.layers14.ff_sub_layer": 0.9999874485656619,
+ "model.base_model.layers15.ln": 0.9999578176066279,
+ "model.base_model.layers15.attn.in_proj.query": 0.9999970393255353,
+ "model.base_model.layers15.attn.in_proj.key": 0.9999960018321872,
+ "model.base_model.layers15.attn.in_proj.value": 0.9999749977141619,
+ "model.base_model.layers15.attn.in_proj": 0.9999749977141619,
+ "model.base_model.layers15.attn.dense": 1.0000038463622332,
+ "model.base_model.layers15.ff_ln": 0.9999995436519384,
+ "model.base_model.layers15.ff_sub_layer.wg": 1.000016150996089,
+ "model.base_model.layers15.ff_sub_layer.a": 0.9999963166192174,
+ "model.base_model.layers15.ff_sub_layer.w1": 0.9999893130734563,
+ "model.base_model.layers15.ff_sub_layer.w2": 1.0000572875142097,
+ "model.base_model.layers15.ff_sub_layer": 1.0000572875142097,
+ "model.base_model.layers16.ln": 0.9999966016039252,
+ "model.base_model.layers16.attn.in_proj.query": 1.0000449502840638,
+ "model.base_model.layers16.attn.in_proj.key": 0.9999969452619553,
+ "model.base_model.layers16.attn.in_proj.value": 1.000001596286893,
+ "model.base_model.layers16.attn.in_proj": 1.000001596286893,
+ "model.base_model.layers16.attn.dense": 0.9999923221766949,
+ "model.base_model.layers16.ff_ln": 1.000006434507668,
+ "model.base_model.layers16.ff_sub_layer.wg": 0.9999946439638734,
+ "model.base_model.layers16.ff_sub_layer.a": 0.9999985871836543,
+ "model.base_model.layers16.ff_sub_layer.w1": 1.000020869076252,
+ "model.base_model.layers16.ff_sub_layer.w2": 1.000023491680622,
+ "model.base_model.layers16.ff_sub_layer": 1.000023491680622,
+ "model.base_model.layers17.ln": 1.0000076917931437,
+ "model.base_model.layers17.attn.in_proj.query": 1.0000273855403066,
+ "model.base_model.layers17.attn.in_proj.key": 0.999996873550117,
+ "model.base_model.layers17.attn.in_proj.value": 1.0000212965533137,
+ "model.base_model.layers17.attn.in_proj": 1.0000212965533137,
+ "model.base_model.layers17.attn.dense": 1.000014752149582,
+ "model.base_model.layers17.ff_ln": 0.9999722754582763,
+ "model.base_model.layers17.ff_sub_layer.wg": 0.9999793535098433,
+ "model.base_model.layers17.ff_sub_layer.a": 0.9999643014743924,
+ "model.base_model.layers17.ff_sub_layer.w1": 0.999996374361217,
+ "model.base_model.layers17.ff_sub_layer.w2": 1.0000149980187416,
+ "model.base_model.layers17.ff_sub_layer": 1.0000149980187416,
+ "model.base_model.layers18.ln": 1.0000415071845055,
+ "model.base_model.layers18.attn.in_proj.query": 1.0000423351302743,
+ "model.base_model.layers18.attn.in_proj.key": 1.000030504539609,
+ "model.base_model.layers18.attn.in_proj.value": 0.9999837735667825,
+ "model.base_model.layers18.attn.in_proj": 0.9999837735667825,
+ "model.base_model.layers18.attn.dense": 1.000011927448213,
+ "model.base_model.layers18.ff_ln": 1.0000248439610004,
+ "model.base_model.layers18.ff_sub_layer.wg": 1.0000116229057312,
+ "model.base_model.layers18.ff_sub_layer.a": 1.000015627592802,
+ "model.base_model.layers18.ff_sub_layer.w1": 1.0000022556632757,
+ "model.base_model.layers18.ff_sub_layer.w2": 1.000029387883842,
+ "model.base_model.layers18.ff_sub_layer": 1.000029387883842,
+ "model.base_model.layers19.ln": 1.0000203093513846,
+ "model.base_model.layers19.attn.in_proj.query": 0.9999857172369957,
+ "model.base_model.layers19.attn.in_proj.key": 1.0000090887770057,
+ "model.base_model.layers19.attn.in_proj.value": 1.0000304272398353,
+ "model.base_model.layers19.attn.in_proj": 1.0000304272398353,
+ "model.base_model.layers19.attn.dense": 1.0000045774504542,
+ "model.base_model.layers19.ff_ln": 0.9999766666442156,
+ "model.base_model.layers19.ff_sub_layer.wg": 0.9999879337847233,
+ "model.base_model.layers19.ff_sub_layer.a": 1.0000366354361176,
+ "model.base_model.layers19.ff_sub_layer.w1": 1.0000044731423259,
+ "model.base_model.layers19.ff_sub_layer.w2": 0.999997922219336,
+ "model.base_model.layers19.ff_sub_layer": 0.999997922219336,
+ "model.base_model.layers20.ln": 0.9999682083725929,
+ "model.base_model.layers20.attn.in_proj.query": 0,
+ "model.base_model.layers20.attn.in_proj.key": 0,
+ "model.base_model.layers20.attn.in_proj.value": 0.9999906904995441,
+ "model.base_model.layers20.attn.in_proj": 0.9999906904995441,
+ "model.base_model.layers20.attn.dense": 1.0000156918540597,
+ "model.base_model.layers20.ff_ln": 0.9999907249584794,
+ "model.base_model.layers20.ff_sub_layer.wg": 0.9999940283596516,
+ "model.base_model.layers20.ff_sub_layer.a": 1.0000511296093464,
+ "model.base_model.layers20.ff_sub_layer.w1": 0.9999824520200491,
+ "model.base_model.layers20.ff_sub_layer.w2": 1.0000235689803958,
+ "model.base_model.layers20.ff_sub_layer": 1.0000235689803958,
+ "model.base_model.layers21.ln": 1.0000102324411273,
+ "model.base_model.layers21.attn.in_proj.query": 0.9999887580052018,
+ "model.base_model.layers21.attn.in_proj.key": 0.9999670209363103,
+ "model.base_model.layers21.attn.in_proj.value": 1.0000033304095268,
+ "model.base_model.layers21.attn.in_proj": 1.0000033304095268,
+ "model.base_model.layers21.attn.dense": 1.0000032112002373,
+ "model.base_model.layers21.ff_ln": 0.9999745050445199,
+ "model.base_model.layers21.ff_sub_layer.wg": 1.0000191712751985,
+ "model.base_model.layers21.ff_sub_layer.a": 0.9999397536739707,
+ "model.base_model.layers21.ff_sub_layer.w1": 0.999976921826601,
+ "model.base_model.layers21.ff_sub_layer.w2": 1.0000147921964526,
+ "model.base_model.layers21.ff_sub_layer": 1.0000147921964526,
+ "model.base_model.layers22.ln": 1.0000217780470848,
+ "model.base_model.layers22.attn.in_proj.query": 1.0000304505228996,
+ "model.base_model.layers22.attn.in_proj.key": 1.0000391714274883,
+ "model.base_model.layers22.attn.in_proj.value": 0.9999985108152032,
+ "model.base_model.layers22.attn.in_proj": 0.9999985108152032,
+ "model.base_model.layers22.attn.dense": 0.9999870825558901,
+ "model.base_model.layers22.ff_ln": 1.000030828639865,
+ "model.base_model.layers22.ff_sub_layer.wg": 0.9999970505014062,
+ "model.base_model.layers22.ff_sub_layer.a": 0.9999665170907974,
+ "model.base_model.layers22.ff_sub_layer.w1": 1.0000078286975622,
+ "model.base_model.layers22.ff_sub_layer.w2": 0.9999847915023565,
+ "model.base_model.layers22.ff_sub_layer": 0.9999847915023565,
+ "model.base_model.layers23.ln": 1.000011671334505,
+ "model.base_model.layers23.attn.in_proj.query": 0,
+ "model.base_model.layers23.attn.in_proj.key": 0,
+ "model.base_model.layers23.attn.in_proj.value": 1.0000066049396992,
+ "model.base_model.layers23.attn.in_proj": 1.0000066049396992,
+ "model.base_model.layers23.attn.dense": 1.000013079494238,
+ "model.base_model.layers23.ff_ln": 0.9999612653627992,
+ "model.base_model.layers23.ff_sub_layer.wg": 1.000030379742384,
+ "model.base_model.layers23.ff_sub_layer.a": 1.0000113025307655,
+ "model.base_model.layers23.ff_sub_layer.w1": 1.0000112522393465,
+ "model.base_model.layers23.ff_sub_layer.w2": 0.9999910769984126,
+ "model.base_model.layers23.ff_sub_layer": 0.9999910769984126,
+ "model.base_model.layers24.ln": 0.9999892562627792,
+ "model.base_model.layers24.attn.in_proj.query": 0,
+ "model.base_model.layers24.attn.in_proj.key": 0,
+ "model.base_model.layers24.attn.in_proj.value": 1.0000211102887988,
+ "model.base_model.layers24.attn.in_proj": 1.0000211102887988,
+ "model.base_model.layers24.attn.dense": 1.0000214129686356,
+ "model.base_model.layers24.ff_ln": 1.000025992281735,
+ "model.base_model.layers24.ff_sub_layer.wg": 1.00003114156425,
+ "model.base_model.layers24.ff_sub_layer.a": 0.999997915700078,
+ "model.base_model.layers24.ff_sub_layer.w1": 0.9999974723905325,
+ "model.base_model.layers24.ff_sub_layer.w2": 0.9999948246404529,
+ "model.base_model.layers24.ff_sub_layer": 0.9999948246404529,
+ "model.base_model.layers25.ln": 1.0000123046338558,
+ "model.base_model.layers25.attn.in_proj.query": 1.000020838342607,
+ "model.base_model.layers25.attn.in_proj.key": 1.000002938322723,
+ "model.base_model.layers25.attn.in_proj.value": 1.000010005198419,
+ "model.base_model.layers25.attn.in_proj": 1.000010005198419,
+ "model.base_model.layers25.attn.dense": 1.0000402312725782,
+ "model.base_model.layers25.ff_ln": 0.9999652067199349,
+ "model.base_model.layers25.ff_sub_layer.wg": 0.9999842857941985,
+ "model.base_model.layers25.ff_sub_layer.a": 1.0000258618965745,
+ "model.base_model.layers25.ff_sub_layer.w1": 1.0000001015141606,
+ "model.base_model.layers25.ff_sub_layer.w2": 0.9999574366956949,
+ "model.base_model.layers25.ff_sub_layer": 0.9999574366956949,
+ "model.base_model.layers26.ln": 0.9999806303530931,
+ "model.base_model.layers26.attn.in_proj.query": 0,
+ "model.base_model.layers26.attn.in_proj.key": 0,
+ "model.base_model.layers26.attn.in_proj.value": 0.9999995995312929,
+ "model.base_model.layers26.attn.in_proj": 0.9999995995312929,
+ "model.base_model.layers26.attn.dense": 1.000026204623282,
+ "model.base_model.layers26.ff_ln": 1.0000002216547728,
+ "model.base_model.layers26.ff_sub_layer.wg": 1.000026123598218,
+ "model.base_model.layers26.ff_sub_layer.a": 1.0000658445060253,
+ "model.base_model.layers26.ff_sub_layer.w1": 0.9999972693622112,
+ "model.base_model.layers26.ff_sub_layer.w2": 0.999981164932251,
+ "model.base_model.layers26.ff_sub_layer": 0.999981164932251,
+ "model.base_model.layers27.ln": 0.9999749213457108,
+ "model.base_model.layers27.attn.in_proj.query": 0,
+ "model.base_model.layers27.attn.in_proj.key": 0,
+ "model.base_model.layers27.attn.in_proj.value": 0.99999382160604,
+ "model.base_model.layers27.attn.in_proj": 0.99999382160604,
+ "model.base_model.layers27.attn.dense": 1.000030618160963,
+ "model.base_model.layers27.ff_ln": 1.0000327732414007,
+ "model.base_model.layers27.ff_sub_layer.wg": 1.0000053877010942,
+ "model.base_model.layers27.ff_sub_layer.a": 0.9999731816351414,
+ "model.base_model.layers27.ff_sub_layer.w1": 0.9999862620607018,
+ "model.base_model.layers27.ff_sub_layer.w2": 1.000004656612873,
+ "model.base_model.layers27.ff_sub_layer": 1.000004656612873,
+ "model.base_model.layers28.ln": 1.0000333590433002,
+ "model.base_model.layers28.attn.in_proj.query": 0.9999866727739573,
+ "model.base_model.layers28.attn.in_proj.key": 0.9999993778765202,
+ "model.base_model.layers28.attn.in_proj.value": 0.9999468578025699,
+ "model.base_model.layers28.attn.in_proj": 0.9999468578025699,
+ "model.base_model.layers28.attn.dense": 1.0000711474567652,
+ "model.base_model.layers28.ff_ln": 1.0000074906274676,
+ "model.base_model.layers28.ff_sub_layer.wg": 0.9999890364706516,
+ "model.base_model.layers28.ff_sub_layer.a": 0.9999930933117867,
+ "model.base_model.layers28.ff_sub_layer.w1": 0.99996502045542,
+ "model.base_model.layers28.ff_sub_layer.w2": 0.9999485462903976,
+ "model.base_model.layers28.ff_sub_layer": 0.9999485462903976,
+ "model.base_model.layers29.ln": 0.9999927263706923,
+ "model.base_model.layers29.attn.in_proj.query": 0,
+ "model.base_model.layers29.attn.in_proj.key": 0,
+ "model.base_model.layers29.attn.in_proj.value": 0.999944906681776,
+ "model.base_model.layers29.attn.in_proj": 0.999944906681776,
+ "model.base_model.layers29.attn.dense": 0.9999553980305791,
+ "model.base_model.layers29.ff_ln": 0.9999844022095203,
+ "model.base_model.layers29.ff_sub_layer.wg": 1.0000193798914552,
+ "model.base_model.layers29.ff_sub_layer.a": 0.999941049143672,
+ "model.base_model.layers29.ff_sub_layer.w1": 1.000042955391109,
+ "model.base_model.layers29.ff_sub_layer.w2": 0.9999570790678263,
+ "model.base_model.layers29.ff_sub_layer": 0.9999570790678263,
+ "model.base_model.layers30.ln": 0.9999837139621377,
+ "model.base_model.layers30.attn.in_proj.query": 0.9999870667234063,
+ "model.base_model.layers30.attn.in_proj.key": 0.9999911524355412,
+ "model.base_model.layers30.attn.in_proj.value": 1.0000267829746008,
+ "model.base_model.layers30.attn.in_proj": 1.0000267829746008,
+ "model.base_model.layers30.attn.dense": 1.000015159137547,
+ "model.base_model.layers30.ff_ln": 0.9999454841017723,
+ "model.base_model.layers30.ff_sub_layer.wg": 1.0000337418168783,
+ "model.base_model.layers30.ff_sub_layer.a": 1.0000011585652828,
+ "model.base_model.layers30.ff_sub_layer.w1": 0.9999842243269086,
+ "model.base_model.layers30.ff_sub_layer.w2": 1.0000219950452447,
+ "model.base_model.layers30.ff_sub_layer": 1.0000219950452447,
+ "model.base_model.layers31.ln": 1.0000553196296096,
+ "model.base_model.layers31.attn.in_proj.query": 1.0000495202839375,
+ "model.base_model.layers31.attn.in_proj.key": 0.9999600946903229,
+ "model.base_model.layers31.attn.in_proj.value": 0.9999898271635175,
+ "model.base_model.layers31.attn.in_proj": 0.9999898271635175,
+ "model.base_model.layers31.attn.dense": 1.0000101793557405,
+ "model.base_model.layers31.ff_ln": 0.9999493975192308,
+ "model.base_model.layers31.ff_sub_layer.wg": 1.0000120643526316,
+ "model.base_model.layers31.ff_sub_layer.a": 0.9999757166951895,
+ "model.base_model.layers31.ff_sub_layer.w1": 0.9999987371265888,
+ "model.base_model.layers31.ff_sub_layer.w2": 1.0000365702435374,
+ "model.base_model.layers31.ff_sub_layer": 1.0000365702435374,
+ "model.base_model.layers32.ln": 0.9999954663217068,
+ "model.base_model.layers32.attn.in_proj.query": 0.9999834299087524,
+ "model.base_model.layers32.attn.in_proj.key": 0.9999754223972559,
+ "model.base_model.layers32.attn.in_proj.value": 1.0000332547351718,
+ "model.base_model.layers32.attn.in_proj": 1.0000332547351718,
+ "model.base_model.layers32.attn.dense": 1.0000128587707877,
+ "model.base_model.layers32.ff_ln": 0.9999977294355631,
+ "model.base_model.layers32.ff_sub_layer.wg": 1.000013922341168,
+ "model.base_model.layers32.ff_sub_layer.a": 0.9999843658879399,
+ "model.base_model.layers32.ff_sub_layer.w1": 1.00000412017107,
+ "model.base_model.layers32.ff_sub_layer.w2": 1.0000331467017531,
+ "model.base_model.layers32.ff_sub_layer": 1.0000331467017531,
+ "model.base_model.layers33.ln": 0.9999776463955641,
+ "model.base_model.layers33.attn.in_proj.query": 0.9999711429700255,
+ "model.base_model.layers33.attn.in_proj.key": 1.000004799105227,
+ "model.base_model.layers33.attn.in_proj.value": 1.000003564171493,
+ "model.base_model.layers33.attn.in_proj": 1.000003564171493,
+ "model.base_model.layers33.attn.dense": 1.0000118184834719,
+ "model.base_model.layers33.ff_ln": 0.9999954793602228,
+ "model.base_model.layers33.ff_sub_layer.wg": 1.000005860812962,
+ "model.base_model.layers33.ff_sub_layer.a": 1.0000032559037209,
+ "model.base_model.layers33.ff_sub_layer.w1": 1.0000015143305063,
+ "model.base_model.layers33.ff_sub_layer.w2": 0.9999763583764434,
+ "model.base_model.layers33.ff_sub_layer": 0.9999763583764434,
+ "model.base_model.layers34.ln": 0.9999581193551421,
+ "model.base_model.layers34.attn.in_proj.query": 1.000028538517654,
+ "model.base_model.layers34.attn.in_proj.key": 0.9999862778931856,
+ "model.base_model.layers34.attn.in_proj.value": 0.9999695559963584,
+ "model.base_model.layers34.attn.in_proj": 0.9999695559963584,
+ "model.base_model.layers34.attn.dense": 1.0000354330986738,
+ "model.base_model.layers34.ff_ln": 0.9999958276748657,
+ "model.base_model.layers34.ff_sub_layer.wg": 0.9999709380790591,
+ "model.base_model.layers34.ff_sub_layer.a": 0.9999592872336507,
+ "model.base_model.layers34.ff_sub_layer.w1": 0.9999475320801139,
+ "model.base_model.layers34.ff_sub_layer.w2": 1.0000180834904313,
+ "model.base_model.layers34.ff_sub_layer": 1.0000180834904313,
+ "model.base_model.layers35.ln": 1.0000238129869103,
+ "model.base_model.layers35.attn.in_proj.query": 0.9999912530183792,
+ "model.base_model.layers35.attn.in_proj.key": 0.9999760650098324,
+ "model.base_model.layers35.attn.in_proj.value": 1.0000254157930613,
+ "model.base_model.layers35.attn.in_proj": 1.0000254157930613,
+ "model.base_model.layers35.attn.dense": 1.000043666921556,
+ "model.base_model.layers35.ff_ln": 0.9999769097194076,
+ "model.base_model.layers35.ff_sub_layer.wg": 1.000048492103815,
+ "model.base_model.layers35.ff_sub_layer.a": 1.000031128525734,
+ "model.base_model.layers35.ff_sub_layer.w1": 1.0000278418883681,
+ "model.base_model.layers35.ff_sub_layer.w2": 0.9999823272228241,
+ "model.base_model.layers35.ff_sub_layer": 0.9999823272228241,
+ "model.base_model.layers36.ln": 1.0000083968043327,
+ "model.base_model.layers36.attn.in_proj.query": 0.9999890839681029,
+ "model.base_model.layers36.attn.in_proj.key": 1.0000366857275367,
+ "model.base_model.layers36.attn.in_proj.value": 0.9999926909804344,
+ "model.base_model.layers36.attn.in_proj": 0.9999926909804344,
+ "model.base_model.layers36.attn.dense": 0.9999915473163128,
+ "model.base_model.layers36.ff_ln": 1.0000037085264921,
+ "model.base_model.layers36.ff_sub_layer.wg": 0.9999989084899426,
+ "model.base_model.layers36.ff_sub_layer.a": 1.0000268882140517,
+ "model.base_model.layers36.ff_sub_layer.w1": 1.0000533871352673,
+ "model.base_model.layers36.ff_sub_layer.w2": 1.0000037532299757,
+ "model.base_model.layers36.ff_sub_layer": 1.0000037532299757,
+ "model.base_model.layers37.ln": 0.9999704547226429,
+ "model.base_model.layers37.attn.in_proj.query": 0.9999823262915015,
+ "model.base_model.layers37.attn.in_proj.key": 0.9999811919406056,
+ "model.base_model.layers37.attn.in_proj.value": 0.9999743653461337,
+ "model.base_model.layers37.attn.in_proj": 0.9999743653461337,
+ "model.base_model.layers37.attn.dense": 1.0000066151842475,
+ "model.base_model.layers37.ff_ln": 0.9999751346185803,
+ "model.base_model.layers37.ff_sub_layer.wg": 0.999997915700078,
+ "model.base_model.layers37.ff_sub_layer.a": 1.0000371094793081,
+ "model.base_model.layers37.ff_sub_layer.w1": 1.0000042142346501,
+ "model.base_model.layers37.ff_sub_layer.w2": 1.000011881813407,
+ "model.base_model.layers37.ff_sub_layer": 1.000011881813407,
+ "model.base_model.layers38.ln": 1.0000276556238532,
+ "model.base_model.layers38.attn.in_proj.query": 0.9999751839786768,
+ "model.base_model.layers38.attn.in_proj.key": 0.9999913470819592,
+ "model.base_model.layers38.attn.in_proj.value": 0.9999802326783538,
+ "model.base_model.layers38.attn.in_proj": 0.9999802326783538,
+ "model.base_model.layers38.attn.dense": 1.0000138999894261,
+ "model.base_model.layers38.ff_ln": 0.9999972302466631,
+ "model.base_model.layers38.ff_sub_layer.wg": 0.9999596020206809,
+ "model.base_model.layers38.ff_sub_layer.a": 0.9999620048329234,
+ "model.base_model.layers38.ff_sub_layer.w1": 1.0000002244487405,
+ "model.base_model.layers38.ff_sub_layer.w2": 1.0000019194558263,
+ "model.base_model.layers38.ff_sub_layer": 1.0000019194558263,
+ "model.base_model.layers39.ln": 0.9999960288405418,
+ "model.base_model.layers39.attn.in_proj.query": 0.9999866196885705,
+ "model.base_model.layers39.attn.in_proj.key": 1.000024433247745,
+ "model.base_model.layers39.attn.in_proj.value": 0.9999685110524297,
+ "model.base_model.layers39.attn.in_proj": 0.9999685110524297,
+ "model.base_model.layers39.attn.dense": 0.9999954961240292,
+ "model.base_model.layers39.ff_ln": 1.0000354265794158,
+ "model.base_model.layers39.ff_sub_layer.wg": 1.0000474276021123,
+ "model.base_model.layers39.ff_sub_layer.a": 1.0000188555568457,
+ "model.base_model.layers39.ff_sub_layer.w1": 0.9999263407662511,
+ "model.base_model.layers39.ff_sub_layer.w2": 0.999988155439496,
+ "model.base_model.layers39.ff_sub_layer": 0.999988155439496,
+ "model.base_model.dec_norm": 1.0000296486541629,
+ "model.base_model": 0,
+ "model.head": 1.0000323532149196
+ },
+ "cos_sim_mean": {
+ "model.base_model.embedding": 1.0000095833092928,
+ "model.base_model.layers0.ln": 1.0000184457749128,
+ "model.base_model.layers0.attn.in_proj.query": 0.9999569654464722,
+ "model.base_model.layers0.attn.in_proj.key": 1.000030318275094,
+ "model.base_model.layers0.attn.in_proj.value": 0.9999886471778154,
+ "model.base_model.layers0.attn.in_proj": 0.9999886471778154,
+ "model.base_model.layers0.attn.dense": 1.0000049602240324,
+ "model.base_model.layers0.ff_ln": 0.9999961135908961,
+ "model.base_model.layers0.ff_sub_layer.wg": 1.0000046007335186,
+ "model.base_model.layers0.ff_sub_layer.a": 1.0000147661194205,
+ "model.base_model.layers0.ff_sub_layer.w1": 1.0000530388206244,
+ "model.base_model.layers0.ff_sub_layer.w2": 0.9999505197629333,
+ "model.base_model.layers0.ff_sub_layer": 0.9999505197629333,
+ "model.base_model.layers1.ln": 0.9999880297109485,
+ "model.base_model.layers1.attn.in_proj.query": 0.9999853884801269,
+ "model.base_model.layers1.attn.in_proj.key": 0.9999964172020555,
+ "model.base_model.layers1.attn.in_proj.value": 0.999986432492733,
+ "model.base_model.layers1.attn.in_proj": 0.999986432492733,
+ "model.base_model.layers1.attn.dense": 0.9999998668208718,
+ "model.base_model.layers1.ff_ln": 1.0000135749578476,
+ "model.base_model.layers1.ff_sub_layer.wg": 1.0000185761600733,
+ "model.base_model.layers1.ff_sub_layer.a": 0.999986195936799,
+ "model.base_model.layers1.ff_sub_layer.w1": 1.000002909451723,
+ "model.base_model.layers1.ff_sub_layer.w2": 1.0000054109841585,
+ "model.base_model.layers1.ff_sub_layer": 1.0000054109841585,
+ "model.base_model.layers2.ln": 1.0000562984496355,
+ "model.base_model.layers2.attn.in_proj.query": 0.9999984446913004,
+ "model.base_model.layers2.attn.in_proj.key": 1.0000433661043644,
+ "model.base_model.layers2.attn.in_proj.value": 0.9999881014227867,
+ "model.base_model.layers2.attn.in_proj": 0.9999881014227867,
+ "model.base_model.layers2.attn.dense": 1.000006778165698,
+ "model.base_model.layers2.ff_ln": 1.0000063749030232,
+ "model.base_model.layers2.ff_sub_layer.wg": 1.0000359807163477,
+ "model.base_model.layers2.ff_sub_layer.a": 0.999996374361217,
+ "model.base_model.layers2.ff_sub_layer.w1": 0.9999919822439551,
+ "model.base_model.layers2.ff_sub_layer.w2": 1.0000033108517528,
+ "model.base_model.layers2.ff_sub_layer": 1.0000033108517528,
+ "model.base_model.layers3.ln": 1.0000399872660637,
+ "model.base_model.layers3.attn.in_proj.query": 0.9999705422669649,
+ "model.base_model.layers3.attn.in_proj.key": 1.0000282796099782,
+ "model.base_model.layers3.attn.in_proj.value": 0.9999693483114243,
+ "model.base_model.layers3.attn.in_proj": 0.9999693483114243,
+ "model.base_model.layers3.attn.dense": 1.0000234134495258,
+ "model.base_model.layers3.ff_ln": 0.9999883593991399,
+ "model.base_model.layers3.ff_sub_layer.wg": 0.9999996954575181,
+ "model.base_model.layers3.ff_sub_layer.a": 1.0000347420573235,
+ "model.base_model.layers3.ff_sub_layer.w1": 1.0000090897083282,
+ "model.base_model.layers3.ff_sub_layer.w2": 0.99996093288064,
+ "model.base_model.layers3.ff_sub_layer": 0.99996093288064,
+ "model.base_model.layers4.ln": 1.0000464990735054,
+ "model.base_model.layers4.attn.in_proj.query": 1.000026903115213,
+ "model.base_model.layers4.attn.in_proj.key": 0.999997797422111,
+ "model.base_model.layers4.attn.in_proj.value": 0.9999980982393026,
+ "model.base_model.layers4.attn.in_proj": 0.9999980982393026,
+ "model.base_model.layers4.attn.dense": 1.0000115931034088,
+ "model.base_model.layers4.ff_ln": 1.0000024493783712,
+ "model.base_model.layers4.ff_sub_layer.wg": 1.0000334102660418,
+ "model.base_model.layers4.ff_sub_layer.a": 1.0000202497467399,
+ "model.base_model.layers4.ff_sub_layer.w1": 1.0000194078311324,
+ "model.base_model.layers4.ff_sub_layer.w2": 1.0000301413238049,
+ "model.base_model.layers4.ff_sub_layer": 1.0000301413238049,
+ "model.base_model.layers5.ln": 0.999975698068738,
+ "model.base_model.layers5.attn.in_proj.query": 1.0000565703958273,
+ "model.base_model.layers5.attn.in_proj.key": 1.0000175070017576,
+ "model.base_model.layers5.attn.in_proj.value": 1.0000161584466696,
+ "model.base_model.layers5.attn.in_proj": 1.0000161584466696,
+ "model.base_model.layers5.attn.dense": 1.000037212856114,
+ "model.base_model.layers5.ff_ln": 0.9999989373609424,
+ "model.base_model.layers5.ff_sub_layer.wg": 0.9999949997290969,
+ "model.base_model.layers5.ff_sub_layer.a": 1.000015702098608,
+ "model.base_model.layers5.ff_sub_layer.w1": 0.9999757846817374,
+ "model.base_model.layers5.ff_sub_layer.w2": 1.0000016959384084,
+ "model.base_model.layers5.ff_sub_layer": 1.0000016959384084,
+ "model.base_model.layers6.ln": 1.0000160429626703,
+ "model.base_model.layers6.attn.in_proj.query": 0.9999967850744724,
+ "model.base_model.layers6.attn.in_proj.key": 0.9999998025596142,
+ "model.base_model.layers6.attn.in_proj.value": 1.0000286241993308,
+ "model.base_model.layers6.attn.in_proj": 1.0000286241993308,
+ "model.base_model.layers6.attn.dense": 1.0000265408307314,
+ "model.base_model.layers6.ff_ln": 1.0000047851353884,
+ "model.base_model.layers6.ff_sub_layer.wg": 1.0000188006088138,
+ "model.base_model.layers6.ff_sub_layer.a": 0.9999990398064256,
+ "model.base_model.layers6.ff_sub_layer.w1": 0.9999480145052075,
+ "model.base_model.layers6.ff_sub_layer.w2": 1.0000073416158557,
+ "model.base_model.layers6.ff_sub_layer": 1.0000073416158557,
+ "model.base_model.layers7.ln": 1.0000189878046513,
+ "model.base_model.layers7.attn.in_proj.query": 1.0000445498153567,
+ "model.base_model.layers7.attn.in_proj.key": 0.9999774135649204,
+ "model.base_model.layers7.attn.in_proj.value": 1.0000242739915848,
+ "model.base_model.layers7.attn.in_proj": 1.0000242739915848,
+ "model.base_model.layers7.attn.dense": 0.9999957242980599,
+ "model.base_model.layers7.ff_ln": 0.999947358854115,
+ "model.base_model.layers7.ff_sub_layer.wg": 0.999967728741467,
+ "model.base_model.layers7.ff_sub_layer.a": 1.000053628347814,
+ "model.base_model.layers7.ff_sub_layer.w1": 1.0000687642022967,
+ "model.base_model.layers7.ff_sub_layer.w2": 0.9999854220077395,
+ "model.base_model.layers7.ff_sub_layer": 0.9999854220077395,
+ "model.base_model.layers8.ln": 1.000000380910933,
+ "model.base_model.layers8.attn.in_proj.query": 0.9999749204143882,
+ "model.base_model.layers8.attn.in_proj.key": 0.9999820357188582,
+ "model.base_model.layers8.attn.in_proj.value": 0.9999982211738825,
+ "model.base_model.layers8.attn.in_proj": 0.9999982211738825,
+ "model.base_model.layers8.attn.dense": 1.0000164853408933,
+ "model.base_model.layers8.ff_ln": 0.9999812999740243,
+ "model.base_model.layers8.ff_sub_layer.wg": 0.9999558068811893,
+ "model.base_model.layers8.ff_sub_layer.a": 0.9999397126957774,
+ "model.base_model.layers8.ff_sub_layer.w1": 0.9999668747186661,
+ "model.base_model.layers8.ff_sub_layer.w2": 0.9999804692342877,
+ "model.base_model.layers8.ff_sub_layer": 0.9999804692342877,
+ "model.base_model.layers9.ln": 1.0000508911907673,
+ "model.base_model.layers9.attn.in_proj.query": 0.9999695355072618,
+ "model.base_model.layers9.attn.in_proj.key": 0.9999892776831985,
+ "model.base_model.layers9.attn.in_proj.value": 0.9999962784349918,
+ "model.base_model.layers9.attn.in_proj": 0.9999962784349918,
+ "model.base_model.layers9.attn.dense": 1.000051881186664,
+ "model.base_model.layers9.ff_ln": 1.0000225473195314,
+ "model.base_model.layers9.ff_sub_layer.wg": 1.0000174287706614,
+ "model.base_model.layers9.ff_sub_layer.a": 0.9999986365437508,
+ "model.base_model.layers9.ff_sub_layer.w1": 1.0000402759760618,
+ "model.base_model.layers9.ff_sub_layer.w2": 0.9999891892075539,
+ "model.base_model.layers9.ff_sub_layer": 0.9999891892075539,
+ "model.base_model.layers10.ln": 0.9999757930636406,
+ "model.base_model.layers10.attn.in_proj.query": 1.0000325981527567,
+ "model.base_model.layers10.attn.in_proj.key": 0.9999954178929329,
+ "model.base_model.layers10.attn.in_proj.value": 1.000049689784646,
+ "model.base_model.layers10.attn.in_proj": 1.000049689784646,
+ "model.base_model.layers10.attn.dense": 0.9999648351222277,
+ "model.base_model.layers10.ff_ln": 1.0000237496569753,
+ "model.base_model.layers10.ff_sub_layer.wg": 0.999984978698194,
+ "model.base_model.layers10.ff_sub_layer.a": 1.0000277813524008,
+ "model.base_model.layers10.ff_sub_layer.w1": 0.9999703979119658,
+ "model.base_model.layers10.ff_sub_layer.w2": 1.000018141232431,
+ "model.base_model.layers10.ff_sub_layer": 1.000018141232431,
+ "model.base_model.layers11.ln": 0.9999837828800082,
+ "model.base_model.layers11.attn.in_proj.query": 0.9999337792396545,
+ "model.base_model.layers11.attn.in_proj.key": 1.0000134026631713,
+ "model.base_model.layers11.attn.in_proj.value": 0.9999888250604272,
+ "model.base_model.layers11.attn.in_proj": 0.9999888250604272,
+ "model.base_model.layers11.attn.dense": 0.9999930802732706,
+ "model.base_model.layers11.ff_ln": 1.0000165477395058,
+ "model.base_model.layers11.ff_sub_layer.wg": 0.999998620711267,
+ "model.base_model.layers11.ff_sub_layer.a": 1.000002202577889,
+ "model.base_model.layers11.ff_sub_layer.w1": 0.9999753125011921,
+ "model.base_model.layers11.ff_sub_layer.w2": 0.9999937638640404,
+ "model.base_model.layers11.ff_sub_layer": 0.9999937638640404,
+ "model.base_model.layers12.ln": 0.9999772552400827,
+ "model.base_model.layers12.attn.in_proj.query": 0.9999653361737728,
+ "model.base_model.layers12.attn.in_proj.key": 0.9999616499990225,
+ "model.base_model.layers12.attn.in_proj.value": 0.9999822629615664,
+ "model.base_model.layers12.attn.in_proj": 0.9999822629615664,
+ "model.base_model.layers12.attn.dense": 0.9999957624822855,
+ "model.base_model.layers12.ff_ln": 0.9999648667871952,
+ "model.base_model.layers12.ff_sub_layer.wg": 1.0000127339735627,
+ "model.base_model.layers12.ff_sub_layer.a": 0.9999869000166655,
+ "model.base_model.layers12.ff_sub_layer.w1": 1.0000022910535336,
+ "model.base_model.layers12.ff_sub_layer.w2": 1.0000239154323936,
+ "model.base_model.layers12.ff_sub_layer": 1.0000239154323936,
+ "model.base_model.layers13.ln": 0.9999950844794512,
+ "model.base_model.layers13.attn.in_proj.query": 0.999962380155921,
+ "model.base_model.layers13.attn.in_proj.key": 1.000015706755221,
+ "model.base_model.layers13.attn.in_proj.value": 1.000016774982214,
+ "model.base_model.layers13.attn.in_proj": 1.000016774982214,
+ "model.base_model.layers13.attn.dense": 0.9999829614534974,
+ "model.base_model.layers13.ff_ln": 0.999973920173943,
+ "model.base_model.layers13.ff_sub_layer.wg": 0.9999783085659146,
+ "model.base_model.layers13.ff_sub_layer.a": 1.0000212853774428,
+ "model.base_model.layers13.ff_sub_layer.w1": 1.0000321120023727,
+ "model.base_model.layers13.ff_sub_layer.w2": 0.9999998137354851,
+ "model.base_model.layers13.ff_sub_layer": 0.9999998137354851,
+ "model.base_model.layers14.ln": 1.000026528723538,
+ "model.base_model.layers14.attn.in_proj.query": 0.9999550497159362,
+ "model.base_model.layers14.attn.in_proj.key": 1.0000005215406418,
+ "model.base_model.layers14.attn.in_proj.value": 1.0000283233821392,
+ "model.base_model.layers14.attn.in_proj": 1.0000283233821392,
+ "model.base_model.layers14.attn.dense": 0.9999569887295365,
+ "model.base_model.layers14.ff_ln": 1.0000332901254296,
+ "model.base_model.layers14.ff_sub_layer.wg": 1.000024825334549,
+ "model.base_model.layers14.ff_sub_layer.a": 0.9999833712354302,
+ "model.base_model.layers14.ff_sub_layer.w1": 1.000031827017665,
+ "model.base_model.layers14.ff_sub_layer.w2": 0.9999874485656619,
+ "model.base_model.layers14.ff_sub_layer": 0.9999874485656619,
+ "model.base_model.layers15.ln": 0.9999578176066279,
+ "model.base_model.layers15.attn.in_proj.query": 0.9999970393255353,
+ "model.base_model.layers15.attn.in_proj.key": 0.9999960018321872,
+ "model.base_model.layers15.attn.in_proj.value": 0.9999749977141619,
+ "model.base_model.layers15.attn.in_proj": 0.9999749977141619,
+ "model.base_model.layers15.attn.dense": 1.0000038463622332,
+ "model.base_model.layers15.ff_ln": 0.9999995436519384,
+ "model.base_model.layers15.ff_sub_layer.wg": 1.000016150996089,
+ "model.base_model.layers15.ff_sub_layer.a": 0.9999963166192174,
+ "model.base_model.layers15.ff_sub_layer.w1": 0.9999893130734563,
+ "model.base_model.layers15.ff_sub_layer.w2": 1.0000572875142097,
+ "model.base_model.layers15.ff_sub_layer": 1.0000572875142097,
+ "model.base_model.layers16.ln": 0.9999966016039252,
+ "model.base_model.layers16.attn.in_proj.query": 1.0000449502840638,
+ "model.base_model.layers16.attn.in_proj.key": 0.9999969452619553,
+ "model.base_model.layers16.attn.in_proj.value": 1.000001596286893,
+ "model.base_model.layers16.attn.in_proj": 1.000001596286893,
+ "model.base_model.layers16.attn.dense": 0.9999923221766949,
+ "model.base_model.layers16.ff_ln": 1.000006434507668,
+ "model.base_model.layers16.ff_sub_layer.wg": 0.9999946439638734,
+ "model.base_model.layers16.ff_sub_layer.a": 0.9999985871836543,
+ "model.base_model.layers16.ff_sub_layer.w1": 1.000020869076252,
+ "model.base_model.layers16.ff_sub_layer.w2": 1.000023491680622,
+ "model.base_model.layers16.ff_sub_layer": 1.000023491680622,
+ "model.base_model.layers17.ln": 1.0000076917931437,
+ "model.base_model.layers17.attn.in_proj.query": 1.0000273855403066,
+ "model.base_model.layers17.attn.in_proj.key": 0.999996873550117,
+ "model.base_model.layers17.attn.in_proj.value": 1.0000212965533137,
+ "model.base_model.layers17.attn.in_proj": 1.0000212965533137,
+ "model.base_model.layers17.attn.dense": 1.000014752149582,
+ "model.base_model.layers17.ff_ln": 0.9999722754582763,
+ "model.base_model.layers17.ff_sub_layer.wg": 0.9999793535098433,
+ "model.base_model.layers17.ff_sub_layer.a": 0.9999643014743924,
+ "model.base_model.layers17.ff_sub_layer.w1": 0.999996374361217,
+ "model.base_model.layers17.ff_sub_layer.w2": 1.0000149980187416,
+ "model.base_model.layers17.ff_sub_layer": 1.0000149980187416,
+ "model.base_model.layers18.ln": 1.0000415071845055,
+ "model.base_model.layers18.attn.in_proj.query": 1.0000423351302743,
+ "model.base_model.layers18.attn.in_proj.key": 1.000030504539609,
+ "model.base_model.layers18.attn.in_proj.value": 0.9999837735667825,
+ "model.base_model.layers18.attn.in_proj": 0.9999837735667825,
+ "model.base_model.layers18.attn.dense": 1.000011927448213,
+ "model.base_model.layers18.ff_ln": 1.0000248439610004,
+ "model.base_model.layers18.ff_sub_layer.wg": 1.0000116229057312,
+ "model.base_model.layers18.ff_sub_layer.a": 1.000015627592802,
+ "model.base_model.layers18.ff_sub_layer.w1": 1.0000022556632757,
+ "model.base_model.layers18.ff_sub_layer.w2": 1.000029387883842,
+ "model.base_model.layers18.ff_sub_layer": 1.000029387883842,
+ "model.base_model.layers19.ln": 1.0000203093513846,
+ "model.base_model.layers19.attn.in_proj.query": 0.9999857172369957,
+ "model.base_model.layers19.attn.in_proj.key": 1.0000090887770057,
+ "model.base_model.layers19.attn.in_proj.value": 1.0000304272398353,
+ "model.base_model.layers19.attn.in_proj": 1.0000304272398353,
+ "model.base_model.layers19.attn.dense": 1.0000045774504542,
+ "model.base_model.layers19.ff_ln": 0.9999766666442156,
+ "model.base_model.layers19.ff_sub_layer.wg": 0.9999879337847233,
+ "model.base_model.layers19.ff_sub_layer.a": 1.0000366354361176,
+ "model.base_model.layers19.ff_sub_layer.w1": 1.0000044731423259,
+ "model.base_model.layers19.ff_sub_layer.w2": 0.999997922219336,
+ "model.base_model.layers19.ff_sub_layer": 0.999997922219336,
+ "model.base_model.layers20.ln": 0.9999682083725929,
+ "model.base_model.layers20.attn.in_proj.query": 0,
+ "model.base_model.layers20.attn.in_proj.key": 0,
+ "model.base_model.layers20.attn.in_proj.value": 0.9999906904995441,
+ "model.base_model.layers20.attn.in_proj": 0.9999906904995441,
+ "model.base_model.layers20.attn.dense": 1.0000156918540597,
+ "model.base_model.layers20.ff_ln": 0.9999907249584794,
+ "model.base_model.layers20.ff_sub_layer.wg": 0.9999940283596516,
+ "model.base_model.layers20.ff_sub_layer.a": 1.0000511296093464,
+ "model.base_model.layers20.ff_sub_layer.w1": 0.9999824520200491,
+ "model.base_model.layers20.ff_sub_layer.w2": 1.0000235689803958,
+ "model.base_model.layers20.ff_sub_layer": 1.0000235689803958,
+ "model.base_model.layers21.ln": 1.0000102324411273,
+ "model.base_model.layers21.attn.in_proj.query": 0.9999887580052018,
+ "model.base_model.layers21.attn.in_proj.key": 0.9999670209363103,
+ "model.base_model.layers21.attn.in_proj.value": 1.0000033304095268,
+ "model.base_model.layers21.attn.in_proj": 1.0000033304095268,
+ "model.base_model.layers21.attn.dense": 1.0000032112002373,
+ "model.base_model.layers21.ff_ln": 0.9999745050445199,
+ "model.base_model.layers21.ff_sub_layer.wg": 1.0000191712751985,
+ "model.base_model.layers21.ff_sub_layer.a": 0.9999397536739707,
+ "model.base_model.layers21.ff_sub_layer.w1": 0.999976921826601,
+ "model.base_model.layers21.ff_sub_layer.w2": 1.0000147921964526,
+ "model.base_model.layers21.ff_sub_layer": 1.0000147921964526,
+ "model.base_model.layers22.ln": 1.0000217780470848,
+ "model.base_model.layers22.attn.in_proj.query": 1.0000304505228996,
+ "model.base_model.layers22.attn.in_proj.key": 1.0000391714274883,
+ "model.base_model.layers22.attn.in_proj.value": 0.9999985108152032,
+ "model.base_model.layers22.attn.in_proj": 0.9999985108152032,
+ "model.base_model.layers22.attn.dense": 0.9999870825558901,
+ "model.base_model.layers22.ff_ln": 1.000030828639865,
+ "model.base_model.layers22.ff_sub_layer.wg": 0.9999970505014062,
+ "model.base_model.layers22.ff_sub_layer.a": 0.9999665170907974,
+ "model.base_model.layers22.ff_sub_layer.w1": 1.0000078286975622,
+ "model.base_model.layers22.ff_sub_layer.w2": 0.9999847915023565,
+ "model.base_model.layers22.ff_sub_layer": 0.9999847915023565,
+ "model.base_model.layers23.ln": 1.000011671334505,
+ "model.base_model.layers23.attn.in_proj.query": 0,
+ "model.base_model.layers23.attn.in_proj.key": 0,
+ "model.base_model.layers23.attn.in_proj.value": 1.0000066049396992,
+ "model.base_model.layers23.attn.in_proj": 1.0000066049396992,
+ "model.base_model.layers23.attn.dense": 1.000013079494238,
+ "model.base_model.layers23.ff_ln": 0.9999612653627992,
+ "model.base_model.layers23.ff_sub_layer.wg": 1.000030379742384,
+ "model.base_model.layers23.ff_sub_layer.a": 1.0000113025307655,
+ "model.base_model.layers23.ff_sub_layer.w1": 1.0000112522393465,
+ "model.base_model.layers23.ff_sub_layer.w2": 0.9999910769984126,
+ "model.base_model.layers23.ff_sub_layer": 0.9999910769984126,
+ "model.base_model.layers24.ln": 0.9999892562627792,
+ "model.base_model.layers24.attn.in_proj.query": 0,
+ "model.base_model.layers24.attn.in_proj.key": 0,
+ "model.base_model.layers24.attn.in_proj.value": 1.0000211102887988,
+ "model.base_model.layers24.attn.in_proj": 1.0000211102887988,
+ "model.base_model.layers24.attn.dense": 1.0000214129686356,
+ "model.base_model.layers24.ff_ln": 1.000025992281735,
+ "model.base_model.layers24.ff_sub_layer.wg": 1.00003114156425,
+ "model.base_model.layers24.ff_sub_layer.a": 0.999997915700078,
+ "model.base_model.layers24.ff_sub_layer.w1": 0.9999974723905325,
+ "model.base_model.layers24.ff_sub_layer.w2": 0.9999948246404529,
+ "model.base_model.layers24.ff_sub_layer": 0.9999948246404529,
+ "model.base_model.layers25.ln": 1.0000123046338558,
+ "model.base_model.layers25.attn.in_proj.query": 1.000020838342607,
+ "model.base_model.layers25.attn.in_proj.key": 1.000002938322723,
+ "model.base_model.layers25.attn.in_proj.value": 1.000010005198419,
+ "model.base_model.layers25.attn.in_proj": 1.000010005198419,
+ "model.base_model.layers25.attn.dense": 1.0000402312725782,
+ "model.base_model.layers25.ff_ln": 0.9999652067199349,
+ "model.base_model.layers25.ff_sub_layer.wg": 0.9999842857941985,
+ "model.base_model.layers25.ff_sub_layer.a": 1.0000258618965745,
+ "model.base_model.layers25.ff_sub_layer.w1": 1.0000001015141606,
+ "model.base_model.layers25.ff_sub_layer.w2": 0.9999574366956949,
+ "model.base_model.layers25.ff_sub_layer": 0.9999574366956949,
+ "model.base_model.layers26.ln": 0.9999806303530931,
+ "model.base_model.layers26.attn.in_proj.query": 0,
+ "model.base_model.layers26.attn.in_proj.key": 0,
+ "model.base_model.layers26.attn.in_proj.value": 0.9999995995312929,
+ "model.base_model.layers26.attn.in_proj": 0.9999995995312929,
+ "model.base_model.layers26.attn.dense": 1.000026204623282,
+ "model.base_model.layers26.ff_ln": 1.0000002216547728,
+ "model.base_model.layers26.ff_sub_layer.wg": 1.000026123598218,
+ "model.base_model.layers26.ff_sub_layer.a": 1.0000658445060253,
+ "model.base_model.layers26.ff_sub_layer.w1": 0.9999972693622112,
+ "model.base_model.layers26.ff_sub_layer.w2": 0.999981164932251,
+ "model.base_model.layers26.ff_sub_layer": 0.999981164932251,
+ "model.base_model.layers27.ln": 0.9999749213457108,
+ "model.base_model.layers27.attn.in_proj.query": 0,
+ "model.base_model.layers27.attn.in_proj.key": 0,
+ "model.base_model.layers27.attn.in_proj.value": 0.99999382160604,
+ "model.base_model.layers27.attn.in_proj": 0.99999382160604,
+ "model.base_model.layers27.attn.dense": 1.000030618160963,
+ "model.base_model.layers27.ff_ln": 1.0000327732414007,
+ "model.base_model.layers27.ff_sub_layer.wg": 1.0000053877010942,
+ "model.base_model.layers27.ff_sub_layer.a": 0.9999731816351414,
+ "model.base_model.layers27.ff_sub_layer.w1": 0.9999862620607018,
+ "model.base_model.layers27.ff_sub_layer.w2": 1.000004656612873,
+ "model.base_model.layers27.ff_sub_layer": 1.000004656612873,
+ "model.base_model.layers28.ln": 1.0000333590433002,
+ "model.base_model.layers28.attn.in_proj.query": 0.9999866727739573,
+ "model.base_model.layers28.attn.in_proj.key": 0.9999993778765202,
+ "model.base_model.layers28.attn.in_proj.value": 0.9999468578025699,
+ "model.base_model.layers28.attn.in_proj": 0.9999468578025699,
+ "model.base_model.layers28.attn.dense": 1.0000711474567652,
+ "model.base_model.layers28.ff_ln": 1.0000074906274676,
+ "model.base_model.layers28.ff_sub_layer.wg": 0.9999890364706516,
+ "model.base_model.layers28.ff_sub_layer.a": 0.9999930933117867,
+ "model.base_model.layers28.ff_sub_layer.w1": 0.99996502045542,
+ "model.base_model.layers28.ff_sub_layer.w2": 0.9999485462903976,
+ "model.base_model.layers28.ff_sub_layer": 0.9999485462903976,
+ "model.base_model.layers29.ln": 0.9999927263706923,
+ "model.base_model.layers29.attn.in_proj.query": 0,
+ "model.base_model.layers29.attn.in_proj.key": 0,
+ "model.base_model.layers29.attn.in_proj.value": 0.999944906681776,
+ "model.base_model.layers29.attn.in_proj": 0.999944906681776,
+ "model.base_model.layers29.attn.dense": 0.9999553980305791,
+ "model.base_model.layers29.ff_ln": 0.9999844022095203,
+ "model.base_model.layers29.ff_sub_layer.wg": 1.0000193798914552,
+ "model.base_model.layers29.ff_sub_layer.a": 0.999941049143672,
+ "model.base_model.layers29.ff_sub_layer.w1": 1.000042955391109,
+ "model.base_model.layers29.ff_sub_layer.w2": 0.9999570790678263,
+ "model.base_model.layers29.ff_sub_layer": 0.9999570790678263,
+ "model.base_model.layers30.ln": 0.9999837139621377,
+ "model.base_model.layers30.attn.in_proj.query": 0.9999870667234063,
+ "model.base_model.layers30.attn.in_proj.key": 0.9999911524355412,
+ "model.base_model.layers30.attn.in_proj.value": 1.0000267829746008,
+ "model.base_model.layers30.attn.in_proj": 1.0000267829746008,
+ "model.base_model.layers30.attn.dense": 1.000015159137547,
+ "model.base_model.layers30.ff_ln": 0.9999454841017723,
+ "model.base_model.layers30.ff_sub_layer.wg": 1.0000337418168783,
+ "model.base_model.layers30.ff_sub_layer.a": 1.0000011585652828,
+ "model.base_model.layers30.ff_sub_layer.w1": 0.9999842243269086,
+ "model.base_model.layers30.ff_sub_layer.w2": 1.0000219950452447,
+ "model.base_model.layers30.ff_sub_layer": 1.0000219950452447,
+ "model.base_model.layers31.ln": 1.0000553196296096,
+ "model.base_model.layers31.attn.in_proj.query": 1.0000495202839375,
+ "model.base_model.layers31.attn.in_proj.key": 0.9999600946903229,
+ "model.base_model.layers31.attn.in_proj.value": 0.9999898271635175,
+ "model.base_model.layers31.attn.in_proj": 0.9999898271635175,
+ "model.base_model.layers31.attn.dense": 1.0000101793557405,
+ "model.base_model.layers31.ff_ln": 0.9999493975192308,
+ "model.base_model.layers31.ff_sub_layer.wg": 1.0000120643526316,
+ "model.base_model.layers31.ff_sub_layer.a": 0.9999757166951895,
+ "model.base_model.layers31.ff_sub_layer.w1": 0.9999987371265888,
+ "model.base_model.layers31.ff_sub_layer.w2": 1.0000365702435374,
+ "model.base_model.layers31.ff_sub_layer": 1.0000365702435374,
+ "model.base_model.layers32.ln": 0.9999954663217068,
+ "model.base_model.layers32.attn.in_proj.query": 0.9999834299087524,
+ "model.base_model.layers32.attn.in_proj.key": 0.9999754223972559,
+ "model.base_model.layers32.attn.in_proj.value": 1.0000332547351718,
+ "model.base_model.layers32.attn.in_proj": 1.0000332547351718,
+ "model.base_model.layers32.attn.dense": 1.0000128587707877,
+ "model.base_model.layers32.ff_ln": 0.9999977294355631,
+ "model.base_model.layers32.ff_sub_layer.wg": 1.000013922341168,
+ "model.base_model.layers32.ff_sub_layer.a": 0.9999843658879399,
+ "model.base_model.layers32.ff_sub_layer.w1": 1.00000412017107,
+ "model.base_model.layers32.ff_sub_layer.w2": 1.0000331467017531,
+ "model.base_model.layers32.ff_sub_layer": 1.0000331467017531,
+ "model.base_model.layers33.ln": 0.9999776463955641,
+ "model.base_model.layers33.attn.in_proj.query": 0.9999711429700255,
+ "model.base_model.layers33.attn.in_proj.key": 1.000004799105227,
+ "model.base_model.layers33.attn.in_proj.value": 1.000003564171493,
+ "model.base_model.layers33.attn.in_proj": 1.000003564171493,
+ "model.base_model.layers33.attn.dense": 1.0000118184834719,
+ "model.base_model.layers33.ff_ln": 0.9999954793602228,
+ "model.base_model.layers33.ff_sub_layer.wg": 1.000005860812962,
+ "model.base_model.layers33.ff_sub_layer.a": 1.0000032559037209,
+ "model.base_model.layers33.ff_sub_layer.w1": 1.0000015143305063,
+ "model.base_model.layers33.ff_sub_layer.w2": 0.9999763583764434,
+ "model.base_model.layers33.ff_sub_layer": 0.9999763583764434,
+ "model.base_model.layers34.ln": 0.9999581193551421,
+ "model.base_model.layers34.attn.in_proj.query": 1.000028538517654,
+ "model.base_model.layers34.attn.in_proj.key": 0.9999862778931856,
+ "model.base_model.layers34.attn.in_proj.value": 0.9999695559963584,
+ "model.base_model.layers34.attn.in_proj": 0.9999695559963584,
+ "model.base_model.layers34.attn.dense": 1.0000354330986738,
+ "model.base_model.layers34.ff_ln": 0.9999958276748657,
+ "model.base_model.layers34.ff_sub_layer.wg": 0.9999709380790591,
+ "model.base_model.layers34.ff_sub_layer.a": 0.9999592872336507,
+ "model.base_model.layers34.ff_sub_layer.w1": 0.9999475320801139,
+ "model.base_model.layers34.ff_sub_layer.w2": 1.0000180834904313,
+ "model.base_model.layers34.ff_sub_layer": 1.0000180834904313,
+ "model.base_model.layers35.ln": 1.0000238129869103,
+ "model.base_model.layers35.attn.in_proj.query": 0.9999912530183792,
+ "model.base_model.layers35.attn.in_proj.key": 0.9999760650098324,
+ "model.base_model.layers35.attn.in_proj.value": 1.0000254157930613,
+ "model.base_model.layers35.attn.in_proj": 1.0000254157930613,
+ "model.base_model.layers35.attn.dense": 1.000043666921556,
+ "model.base_model.layers35.ff_ln": 0.9999769097194076,
+ "model.base_model.layers35.ff_sub_layer.wg": 1.000048492103815,
+ "model.base_model.layers35.ff_sub_layer.a": 1.000031128525734,
+ "model.base_model.layers35.ff_sub_layer.w1": 1.0000278418883681,
+ "model.base_model.layers35.ff_sub_layer.w2": 0.9999823272228241,
+ "model.base_model.layers35.ff_sub_layer": 0.9999823272228241,
+ "model.base_model.layers36.ln": 1.0000083968043327,
+ "model.base_model.layers36.attn.in_proj.query": 0.9999890839681029,
+ "model.base_model.layers36.attn.in_proj.key": 1.0000366857275367,
+ "model.base_model.layers36.attn.in_proj.value": 0.9999926909804344,
+ "model.base_model.layers36.attn.in_proj": 0.9999926909804344,
+ "model.base_model.layers36.attn.dense": 0.9999915473163128,
+ "model.base_model.layers36.ff_ln": 1.0000037085264921,
+ "model.base_model.layers36.ff_sub_layer.wg": 0.9999989084899426,
+ "model.base_model.layers36.ff_sub_layer.a": 1.0000268882140517,
+ "model.base_model.layers36.ff_sub_layer.w1": 1.0000533871352673,
+ "model.base_model.layers36.ff_sub_layer.w2": 1.0000037532299757,
+ "model.base_model.layers36.ff_sub_layer": 1.0000037532299757,
+ "model.base_model.layers37.ln": 0.9999704547226429,
+ "model.base_model.layers37.attn.in_proj.query": 0.9999823262915015,
+ "model.base_model.layers37.attn.in_proj.key": 0.9999811919406056,
+ "model.base_model.layers37.attn.in_proj.value": 0.9999743653461337,
+ "model.base_model.layers37.attn.in_proj": 0.9999743653461337,
+ "model.base_model.layers37.attn.dense": 1.0000066151842475,
+ "model.base_model.layers37.ff_ln": 0.9999751346185803,
+ "model.base_model.layers37.ff_sub_layer.wg": 0.999997915700078,
+ "model.base_model.layers37.ff_sub_layer.a": 1.0000371094793081,
+ "model.base_model.layers37.ff_sub_layer.w1": 1.0000042142346501,
+ "model.base_model.layers37.ff_sub_layer.w2": 1.000011881813407,
+ "model.base_model.layers37.ff_sub_layer": 1.000011881813407,
+ "model.base_model.layers38.ln": 1.0000276556238532,
+ "model.base_model.layers38.attn.in_proj.query": 0.9999751839786768,
+ "model.base_model.layers38.attn.in_proj.key": 0.9999913470819592,
+ "model.base_model.layers38.attn.in_proj.value": 0.9999802326783538,
+ "model.base_model.layers38.attn.in_proj": 0.9999802326783538,
+ "model.base_model.layers38.attn.dense": 1.0000138999894261,
+ "model.base_model.layers38.ff_ln": 0.9999972302466631,
+ "model.base_model.layers38.ff_sub_layer.wg": 0.9999596020206809,
+ "model.base_model.layers38.ff_sub_layer.a": 0.9999620048329234,
+ "model.base_model.layers38.ff_sub_layer.w1": 1.0000002244487405,
+ "model.base_model.layers38.ff_sub_layer.w2": 1.0000019194558263,
+ "model.base_model.layers38.ff_sub_layer": 1.0000019194558263,
+ "model.base_model.layers39.ln": 0.9999960288405418,
+ "model.base_model.layers39.attn.in_proj.query": 0.9999866196885705,
+ "model.base_model.layers39.attn.in_proj.key": 1.000024433247745,
+ "model.base_model.layers39.attn.in_proj.value": 0.9999685110524297,
+ "model.base_model.layers39.attn.in_proj": 0.9999685110524297,
+ "model.base_model.layers39.attn.dense": 0.9999954961240292,
+ "model.base_model.layers39.ff_ln": 1.0000354265794158,
+ "model.base_model.layers39.ff_sub_layer.wg": 1.0000474276021123,
+ "model.base_model.layers39.ff_sub_layer.a": 1.0000188555568457,
+ "model.base_model.layers39.ff_sub_layer.w1": 0.9999263407662511,
+ "model.base_model.layers39.ff_sub_layer.w2": 0.999988155439496,
+ "model.base_model.layers39.ff_sub_layer": 0.999988155439496,
+ "model.base_model.dec_norm": 1.0000296486541629,
+ "model.base_model": 0,
+ "model.head": 1.0000323532149196
+ }
+}
\ No newline at end of file