Miscellaneous notes from investigating the HuggingFace libraries
Transformers
Q. model = transformers.CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
Which of the PyTorch, TensorFlow, Jax, or Safetensors weights does this call load?
- (cite): https://huggingface.co/openai/clip-vit-large-patch14/blob/main/README.md
- According to README.md, the pretrained model is loaded as follows.
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel
model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
- (cite): https://huggingface.co/openai/clip-vit-large-patch14/tree/main
- However, the HuggingFace Hub model repository distributes the weights in four formats, PyTorch (pytorch_model.bin), TensorFlow (tf_model.h5), Jax (flax_model.msgpack), and Safetensors (model.safetensors), and the README does not say which one gets loaded.
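A. It is controlled by keyword arguments of from_pretrained: from_tf=True loads the TensorFlow weights (tf_model.h5), from_flax=True loads the Jax/Flax weights (flax_model.msgpack), and use_safetensors=True loads the Safetensors weights (model.safetensors). With none of these set (use_safetensors defaults to None), the code at the commit cited below tries model.safetensors first if the safetensors package is installed and falls back to the PyTorch weights (pytorch_model.bin) when it is missing; without safetensors installed it loads pytorch_model.bin directly. Whichever file is read, CLIPModel is the PyTorch class, so TensorFlow and Flax weights are converted into a PyTorch model on load (which requires TensorFlow or Flax to be installed).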
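Reading the from_pretrained code at the cited commit, the choice of weight file for a Hub repo can be sketched as a small pure-Python function. This is a hypothetical, simplified model (pick_weights_file and its safetensors_installed parameter are names made up for this note, not part of the transformers API): it ignores sharded *.index.json checkpoints, variants, local directories, and the actual loading and conversion, and only shows which file name is picked.

```python
# Hypothetical, simplified model of how PreTrainedModel.from_pretrained picks a
# weight file from a Hub repo. Sharded checkpoints (*.index.json), variants,
# local directories and the actual weight loading are deliberately left out.
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
SAFE_WEIGHTS_NAME = "model.safetensors"


def pick_weights_file(repo_files, from_tf=False, from_flax=False,
                      use_safetensors=None, safetensors_installed=True):
    """Return the weight file name that would be used from `repo_files`."""
    # The default None is downgraded to False when safetensors is unavailable.
    if use_safetensors is None and not safetensors_installed:
        use_safetensors = False
    if use_safetensors is True and not safetensors_installed:
        # This sketch simply refuses the combination; it does not model the real error.
        raise ImportError("use_safetensors=True needs the safetensors package")

    if from_tf:
        candidates = [TF2_WEIGHTS_NAME]
    elif from_flax:
        candidates = [FLAX_WEIGHTS_NAME]
    elif use_safetensors is None:
        # Default: prefer safetensors, fall back to the PyTorch pickle file.
        candidates = [SAFE_WEIGHTS_NAME, WEIGHTS_NAME]
    elif use_safetensors:
        # Explicit True: no fallback, a missing file is an error.
        candidates = [SAFE_WEIGHTS_NAME]
    else:
        candidates = [WEIGHTS_NAME]

    for name in candidates:
        if name in repo_files:
            return name
    raise FileNotFoundError(f"none of {candidates} found in the repo")


if __name__ == "__main__":
    files = {WEIGHTS_NAME, TF2_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME}
    print(pick_weights_file(files))                         # prints model.safetensors
    print(pick_weights_file(files, use_safetensors=False))  # prints pytorch_model.bin
```

Note that in this model TensorFlow or Flax weights are never chosen unless from_tf or from_flax is passed explicitly, even if they are the only files in the repo.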
class CLIPModel(CLIPPreTrainedModel):
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/models/clip/modeling_clip.py#L938
- ↓
class CLIPPreTrainedModel(PreTrainedModel):
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/models/clip/modeling_clip.py#L399C1-L399C44
- ↓
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/modeling_utils.py#L1223
- ↓
def from_pretrained(
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/modeling_utils.py#L2724
- Comments (docstring):
from_tf: load the model weights from a TensorFlow checkpoint save file.
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/modeling_utils.py#L2793C13-L2793C20
from_flax: load the model weights from a Flax checkpoint save file.
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/modeling_utils.py#L2796
use_safetensors: whether to use safetensors checkpoints. Defaults to None; if not specified and safetensors is not installed, it is set to False.
(cite): https://github.com/huggingface/transformers/blob/91d155ea92da372b319a79dd4eef69533ee15170/src/transformers/modeling_utils.py#L2907C13-L2907C29
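The chain above works because from_pretrained is a classmethod defined once on PreTrainedModel and inherited by CLIPPreTrainedModel and CLIPModel. The following toy hierarchy is a sketch of that mechanism only: the three classes are hypothetical stand-ins with the same names, and defining_class is a helper made up for this note that walks the MRO to find where an attribute is actually defined.

```python
# Hypothetical stand-ins mirroring CLIPModel -> CLIPPreTrainedModel -> PreTrainedModel.
# Only the classmethod inheritance is modelled; nothing is downloaded or loaded.
class PreTrainedModel:
    @classmethod
    def from_pretrained(cls, name, *, from_tf=False, from_flax=False,
                        use_safetensors=None):
        # `cls` is the class the call was made on (here CLIPModel), which is how a
        # method defined once on the base class can build the right subclass.
        instance = cls()
        instance.loaded_from = name
        instance.flags = {"from_tf": from_tf, "from_flax": from_flax,
                          "use_safetensors": use_safetensors}
        return instance


class CLIPPreTrainedModel(PreTrainedModel):
    pass


class CLIPModel(CLIPPreTrainedModel):
    pass


def defining_class(klass, attr):
    """Return the first class in klass.__mro__ whose own namespace defines `attr`."""
    for base in klass.__mro__:
        if attr in vars(base):
            return base
    raise AttributeError(attr)


model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14", from_tf=True)
# prints: CLIPModel PreTrainedModel
print(type(model).__name__, defining_class(CLIPModel, "from_pretrained").__name__)
```

With transformers and PyTorch installed, calling defining_class(transformers.CLIPModel, "from_pretrained") on the real class should likewise return transformers.PreTrainedModel, as long as the chain above still holds in the installed version.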
- Code:
- Comment:
For reference, the weight file names that huggingface/transformers expects are listed below. (cite): src/transformers/utils/__init__.py#L220C1-L234C35
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
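To see how these constants line up with the files in a Hub repository, the following hypothetical helper (formats_in_repo and the FORMATS labels are made up for this note) groups a repository's file listing by weight format, treating the *.index.json name as the sharded variant of the same format. TF_WEIGHTS_NAME ("model.ckpt") is left out of this sketch for simplicity.

```python
# File-name constants as listed in src/transformers/utils/__init__.py (see above).
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
FLAX_WEIGHTS_NAME = "flax_model.msgpack"
FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"

# Hypothetical grouping for illustration: format label -> (single file, shard index).
FORMATS = {
    "PyTorch": (WEIGHTS_NAME, WEIGHTS_INDEX_NAME),
    "TensorFlow 2": (TF2_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME),
    "Flax/JAX": (FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_INDEX_NAME),
    "Safetensors": (SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME),
}


def formats_in_repo(repo_files):
    """Map each format label to the matching weight file names in `repo_files`."""
    present = set(repo_files)
    found = {}
    for label, names in FORMATS.items():
        hits = [name for name in names if name in present]
        if hits:
            found[label] = hits
    return found
```

A real file listing can be obtained with huggingface_hub.list_repo_files("openai/clip-vit-large-patch14") (network access required) and passed to formats_in_repo directly.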