The variables in |
One thing I've been doing is using the JAX backend with `vmap` over a model's `stateless_call` to compose small models into bigger models with different batch sizes. A nuance of this is that the returned `non_trainable_variables` come back vectorised, so (depending on the use case, of course) they probably need to be reduced to a single set of variables before the next call. Depending on the effective "type" of each non-trainable variable, this reduction might be done in different ways; e.g. I commonly reduce batch norm stats with code like ... (though it's kinda clunky, since it ends up being a mean-of-means situation).
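A minimal sketch of that pattern, assuming pure JAX rather than Keras: `toy_stateless_call` is a hypothetical stand-in for a model's `stateless_call` that returns `(outputs, non_trainable_variables)` for a single batch-norm-like layer, and `MOMENTUM` and `reduce_stats` are illustrative names. Mapping it with `jax.vmap` over a leading group axis returns state with that extra axis, which `reduce_stats` then collapses by averaging.

```python
import jax
import jax.numpy as jnp

MOMENTUM = 0.99  # hypothetical momentum, similar to a typical batch-norm setting


def toy_stateless_call(non_trainable, x):
    """Hypothetical stand-in for stateless_call: returns (outputs, new non-trainable state).

    `non_trainable` is [moving_mean, moving_var], like a batch-norm layer's state.
    """
    moving_mean, moving_var = non_trainable
    batch_mean = jnp.mean(x, axis=0)
    batch_var = jnp.var(x, axis=0)
    new_mean = MOMENTUM * moving_mean + (1.0 - MOMENTUM) * batch_mean
    new_var = MOMENTUM * moving_var + (1.0 - MOMENTUM) * batch_var
    outputs = (x - batch_mean) / jnp.sqrt(batch_var + 1e-3)
    return outputs, [new_mean, new_var]


# The state is shared across groups (in_axes=None) while the inputs are split
# along axis 0, so each returned state leaf gains a leading axis of size n_groups.
vmapped_call = jax.vmap(toy_stateless_call, in_axes=(None, 0))


def reduce_stats(vectorised_state):
    # Collapse the leading vmap axis by averaging: a mean of per-group statistics.
    return [jnp.mean(v, axis=0) for v in vectorised_state]


features = 4
state = [jnp.zeros((features,)), jnp.ones((features,))]
# 3 groups of 8 examples each.
x = jnp.arange(3 * 8 * features, dtype=jnp.float32).reshape(3, 8, features)
outputs, vec_state = vmapped_call(state, x)
new_state = reduce_stats(vec_state)
```

Because the moving-mean update is linear in the batch mean, averaging over equal-sized groups reproduces the full-batch update exactly. Averaging the variance update does not: it drops the spread between the group means (law of total variance), which is one concrete way the mean-of-means reduction falls short.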
I can't do this for all the `non_trainable_variables`, though, if I have something like dropout, since they'll include things like the older-style `uint32` RNG seeds. I think I can infer the type (safely?) by checking `KerasVariable.path`, but that seems dangerous. Am I missing some way to get type information for these variables from the model itself, beyond the `path`?
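As a hedged complement to matching on `path`, the reduction can be chosen from each leaf's dtype. This sketch is a heuristic, not a Keras API: `reduce_by_dtype` and the dictionary keys are hypothetical. Floating-point leaves are averaged over the vmap axis, and integer leaves such as `uint32` seed state keep their first slice, on the assumption that every slice carries the same value (e.g. because that state was broadcast into `vmap`).

```python
import jax
import jax.numpy as jnp


def reduce_by_dtype(vectorised_state):
    """Heuristic (an assumption, not a Keras feature): pick a reduction per leaf dtype.

    - floating-point leaves (e.g. batch-norm moving stats) are averaged over axis 0;
    - integer leaves (e.g. uint32 RNG seed state) can't be meaningfully averaged,
      so the first slice along the vmap axis is kept instead.
    """
    def reduce_leaf(v):
        if jnp.issubdtype(v.dtype, jnp.floating):
            return jnp.mean(v, axis=0)
        return v[0]

    return jax.tree_util.tree_map(reduce_leaf, vectorised_state)


# Hypothetical vectorised state from a vmap over 3 sub-models.
vec_state = {
    "batch_norm/moving_mean": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    "dropout/seed": jnp.array([[7, 11], [7, 11], [7, 11]], dtype=jnp.uint32),
}
reduced = reduce_by_dtype(vec_state)
```

dtype alone can't tell a seed apart from, say, an integer step counter, so this narrows the problem rather than replacing a check on `path`.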