Modular · 5mo ago
Hylke

How to define a function with a nested dict argument

(This is a shortened cross-post of my Stack Overflow question.) How do I define a function in Mojo where the argument or return value is a nested structure, like a dictionary, whose exact structure is not known at compile time because, for example, it is read from disk? Concretely, how do I type-annotate params in the function neural_network so that it represents a nested dict holding the model parameters (i.e., a pytree)?

```mojo
network_params = {
    'linear1': {'weights': Tensor(...), 'bias': 1.0},
    ...
}

fn neural_network(params: Dict, x: Tensor) -> Float64:
    x = params['linear1']['weights'] @ x + params['linear1']['bias']
    ...
    return x

x_input = Tensor(...)
y = neural_network(network_params, x_input)
```
Thanks so much!
Stack Overflow: How to define a nested dictionary argument of arbitrary depth
3 Replies
capt_falamer · 4mo ago
This is probably not the best way to do what you're looking for, but see the two examples below. Judging from your example, though, I wonder whether you need a nested dict at all, or could just store each type in its own array.
```mojo
from collections import Dict
from utils import Variant


# Example 1: each dict value is either a String leaf or a nested Node.
@value
struct Node:
    var dict: Dict[String, Variant[String, Node]]

    fn __setitem__(inout self, key: String, val: Variant[String, Node]):
        self.dict[key] = val

    fn __getitem__(self, key: String) raises -> Variant[String, Node]:
        if key in self.dict:
            return self.dict[key]
        else:
            # Missing keys return a placeholder string instead of raising.
            return Variant[String, Node](str('Does not exist'))


# Example 2: every Node2 holds a scalar value plus a dict of child nodes.
alias str_int = Variant[String, Int, StringLiteral]


@value
struct Node2:
    var dict: Dict[String, Node2]
    var value: str_int

    fn __setitem__(inout self, key: String, n: Node2):
        self.dict[key] = n

    fn __getitem__(self, key: String) raises -> Node2:
        if key in self.dict:
            return self.dict[key]
        else:
            # Missing keys return an empty node.
            return Node2(Dict[String, Node2](), '')


def main():
    n = Node(Dict[String, Variant[String, Node]]())
    var a1 = Variant[String, Node](str('123'))
    var a2 = Variant[String, Node](str('Pike'))
    var a3 = Variant[String, Node](str('NY'))
    var add = Node(Dict[String, Variant[String, Node]]())
    add['street'] = a1
    add['city'] = a2
    add['state'] = a3
    n['Address'] = add
    # Select the Node alternative of the Variant, then the String leaf.
    print(n['Address'][Node]['street'][String])
    print(n['Address'][Node]['city'][String])
    print(n['Address'][Node]['state'][String])

    n2 = Node2(Dict[String, Node2](), str_int(''))

    n2['user'] = Node2(Dict[String, Node2](), str_int('Allise'))
    print(n2['user'].value[StringLiteral])

    n2['address'] = Node2(Dict[String, Node2](), str_int(''))
    n2['address']['street'] = Node2(Dict[String, Node2](), str_int(123))
    n2['address']['city'] = Node2(Dict[String, Node2](), str_int('Pike'))
    n2['address']['state'] = Node2(Dict[String, Node2](), str_int('NY'))
```
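Applying the Node2 idea to the original neural_network question, a minimal sketch might look like the following. It assumes the same Mojo APIs the answer relies on (`@value`, `inout`, `Dict`, and `Variant` with `isa[T]()` and `[T]` access), which have changed between Mojo releases. `ParamTree`, `Leaf`, `empty_tree` and `count_params` are hypothetical names, and `List[Float64]` stands in for a `Tensor` so the sketch does not depend on any Tensor API. Because the function signature only names the recursive struct, it accepts a tree of any depth built at runtime.

```mojo
from collections import Dict, List
from utils import Variant

# Hypothetical leaf type: a scalar (such as a bias) or a flat vector that
# stands in for a Tensor.
alias Leaf = Variant[Float64, List[Float64]]


@value
struct ParamTree(CollectionElement):
    # Named subtrees, e.g. "linear1" -> the parameters of that layer.
    var children: Dict[String, ParamTree]
    # Named leaf values at this level, e.g. "weights" and "bias".
    var leaves: Dict[String, Leaf]


fn empty_tree() -> ParamTree:
    # Hypothetical helper: a node with no children and no leaves.
    return ParamTree(Dict[String, ParamTree](), Dict[String, Leaf]())


fn count_params(tree: ParamTree) raises -> Int:
    # Counts scalar parameters in the whole tree, whatever its depth.
    var total = 0
    for entry in tree.leaves.items():
        var leaf = entry[].value
        if leaf.isa[Float64]():
            total += 1
        else:
            total += len(leaf[List[Float64]])
    # Recurse into every subtree.
    for entry in tree.children.items():
        total += count_params(entry[].value)
    return total


def main():
    var linear1 = empty_tree()
    linear1.leaves["weights"] = Leaf(List[Float64](0.1, 0.2, 0.3))
    linear1.leaves["bias"] = Leaf(Float64(1.0))

    var params = empty_tree()
    params.children["linear1"] = linear1
    print(count_params(params))
```

With this shape, the question's function could take `params: ParamTree` and walk `params.children["linear1"].leaves` instead of indexing an untyped dict; the trade-off is that every leaf access has to check which Variant alternative it holds.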
capt_falamer · 4mo ago
Also, if you are looking at parsing (or want more fleshed-out code), I recommend checking out https://www.modular.com/modverse/modverse-42, specifically Pholmola's project.
Modverse #42: Magic is the best way to build with MAX and Mojo
Welcome to Modverse #42, covering blogs, videos, tutorials, community projects, MAX, and Mojo!
Hylke (OP) · 4mo ago
Thanks for the pointers, that's a big help!
